diff --git a/rock/actions/sandbox/response.py b/rock/actions/sandbox/response.py index 1e44d26017..283373d0f5 100644 --- a/rock/actions/sandbox/response.py +++ b/rock/actions/sandbox/response.py @@ -16,6 +16,8 @@ class State(str, Enum): PENDING = "pending" RUNNING = "running" STOPPED = "stopped" + ARCHIVING = "archiving" + ARCHIVED = "archived" DELETED = "deleted" diff --git a/rock/actions/sandbox/sandbox_info.py b/rock/actions/sandbox/sandbox_info.py index 33ee9ab1d5..1c40a076bc 100644 --- a/rock/actions/sandbox/sandbox_info.py +++ b/rock/actions/sandbox/sandbox_info.py @@ -26,6 +26,8 @@ class SandboxInfo(TypedDict, total=False): start_time: str stop_time: str delete_time: str + archive_time: str + state_enter_time: str extended_params: dict[str, str] diff --git a/rock/admin/core/schema.py b/rock/admin/core/schema.py index d9e22c5579..02176b39b6 100644 --- a/rock/admin/core/schema.py +++ b/rock/admin/core/schema.py @@ -46,6 +46,9 @@ class SandboxRecord(Base): cpus = Column(Float, nullable=True) memory = Column(String(64), nullable=True) create_user_gray_flag = Column(Boolean, nullable=True) + archive_time = Column(String(64), nullable=True) + state_enter_time = Column(String(64), nullable=True) + delete_time = Column(String(64), nullable=True) phases = Column(_JSONB_VARIANT, nullable=True) port_mapping = Column(_JSONB_VARIANT, nullable=True) spec = Column(_JSONB_VARIANT, nullable=True) diff --git a/rock/admin/entrypoints/sandbox_api.py b/rock/admin/entrypoints/sandbox_api.py index f83fc0ec7f..8c1427bcb8 100644 --- a/rock/admin/entrypoints/sandbox_api.py +++ b/rock/admin/entrypoints/sandbox_api.py @@ -1,7 +1,7 @@ import math from typing import Annotated, Any -from fastapi import APIRouter, Body, Depends, File, Form, UploadFile +from fastapi import APIRouter, Body, Depends, File, Form, Header, UploadFile from rock.actions import ( BashObservation, @@ -117,7 +117,7 @@ async def _apply_accelerator_type_validation(config: DockerDeploymentConfig) -> if 
def _apply_auto_clear_default(config: DockerDeploymentConfig, request: SandboxStartRequest) -> None:
    """Fall back to the lifecycle default when the caller left ``auto_clear_time_minutes`` unset.

    The request's ``model_fields_set`` is consulted so that an explicitly
    supplied value is never overwritten. Otherwise
    ``lifecycle.auto_clear_default_sec`` is converted to whole minutes with
    floor division and written onto ``config`` in place.
    """
    if "auto_clear_time_minutes" in request.model_fields_set:
        # The SDK caller chose a value explicitly; leave it alone.
        return
    lifecycle_cfg = sandbox_manager.rock_config.lifecycle
    config.auto_clear_time_minutes = lifecycle_cfg.auto_clear_default_sec // 60
sandbox_manager.rock_config.lifecycle.archive + if not archive_cfg.is_allowed(rock_authorization): + raise BadRequestRockError("archive not allowed for this authorization key") + await sandbox_manager.archive_sandbox(sandbox_id) + return RockResponse(result=f"{sandbox_id} archiving") + + @sandbox_router.post("/restart") @handle_exceptions(error_message="restart sandbox failed") async def restart(sandbox_id: str = Body(..., embed=True)) -> RockResponse[SandboxStartResponse]: diff --git a/rock/admin/main.py b/rock/admin/main.py index a3c7537f40..9f81702a42 100644 --- a/rock/admin/main.py +++ b/rock/admin/main.py @@ -38,8 +38,11 @@ ) from rock.admin.service.ops_service import OpsService from rock.common.exception import request_validation_exception_handler -from rock.config import DatabaseConfig, RockConfig, SchedulerConfig +from rock.config import DatabaseConfig, RockConfig, SandboxLifecycleConfig, SchedulerConfig from rock.logger import init_logger +from rock.sandbox.archive.oss_storage import OssDirStorage +from rock.sandbox.archive.registry_v2 import DockerRegistryV2ImageStorage +from rock.sandbox.archive.s3_storage import S3DirStorage from rock.sandbox.gem_manager import GemManager from rock.sandbox.operator.factory import OperatorContext, OperatorFactory from rock.sandbox.sandbox_meta_store import SandboxMetaStore @@ -93,12 +96,16 @@ async def lifespan(app: FastAPI): ) rock_config = RockConfig.from_env(config_file_path) - # Override scheduler config from Nacos if available + # Override configs from Nacos if available if rock_config.nacos_provider: nacos_config = await rock_config.nacos_provider.get_config() - if nacos_config and "scheduler" in nacos_config: - rock_config.scheduler = SchedulerConfig(**nacos_config["scheduler"]) - logger.info(f"Overrode scheduler config from Nacos with {len(rock_config.scheduler.tasks)} tasks") + if nacos_config: + if "scheduler" in nacos_config: + rock_config.scheduler = SchedulerConfig(**nacos_config["scheduler"]) + 
logger.info(f"Overrode scheduler config from Nacos with {len(rock_config.scheduler.tasks)} tasks") + if "lifecycle" in nacos_config: + rock_config.lifecycle = SandboxLifecycleConfig(**nacos_config["lifecycle"]) + logger.info(f"Overrode lifecycle config from Nacos: {rock_config.lifecycle}") env_vars.ROCK_ADMIN_ENV = args.env env_vars.ROCK_ADMIN_ROLE = args.role @@ -179,6 +186,37 @@ async def lifespan(app: FastAPI): meta_store=meta_store, ) set_sandbox_manager(sandbox_manager) + + archive_cfg = rock_config.lifecycle.archive + if archive_cfg.enabled: + acr = archive_cfg.acr + ds = archive_cfg.dir_storage + acr_ready = acr.registry_url and acr.username and acr.password + ds_ready = ds.endpoint and ds.access_key_id and ds.access_key_secret + if not (acr_ready and ds_ready): + raise RuntimeError("archive.enabled=true but ACR or dir_storage credentials are missing") + if ds.type == "s3": + sandbox_manager._dir_storage = S3DirStorage( + endpoint=ds.endpoint, + bucket=ds.bucket, + access_key_id=ds.access_key_id, + access_key_secret=ds.access_key_secret, + region=ds.region, + ) + else: + sandbox_manager._dir_storage = OssDirStorage( + endpoint=ds.endpoint, + bucket=ds.bucket, + access_key_id=ds.access_key_id, + access_key_secret=ds.access_key_secret, + region=ds.region, + ) + sandbox_manager._image_storage = DockerRegistryV2ImageStorage( + registry_url=acr.registry_url, + username=acr.username, + password=acr.password, + ) + warmup_service = WarmupService(rock_config.warmup) await warmup_service.init() set_warmup_service(warmup_service) diff --git a/rock/config.py b/rock/config.py index e2b2b7d8f7..a458d1593d 100644 --- a/rock/config.py +++ b/rock/config.py @@ -153,6 +153,66 @@ def __post_init__(self): self.primary = OssAccountConfig(**self.primary) +@dataclass +class ArchiveDirStorageConfig: + type: str = "oss" + endpoint: str = "" + bucket: str = "" + access_key_id: str = "" + access_key_secret: str = "" + region: str = "" + + +@dataclass +class ArchiveAcrConfig: + 
registry_url: str = "" + username: str = "" + password: str = "" + namespace: str = "sandbox_archive" + + +@dataclass +class ArchiveConfig: + enabled: bool = False + allowed_keys: list[str] | None = None + dir_storage: ArchiveDirStorageConfig = field(default_factory=ArchiveDirStorageConfig) + acr: ArchiveAcrConfig = field(default_factory=ArchiveAcrConfig) + max_retries: int = 3 + prefix: str = "rock-archives/" + max_image_push_size: str = "16g" + max_dir_upload_size: str = "16g" + + def is_allowed(self, rock_authorization: str | None) -> bool: + """Check if the caller is allowed to use archive. None = open to all.""" + if self.allowed_keys is None: + return True + if not isinstance(self.allowed_keys, list): + return False + return rock_authorization in self.allowed_keys + + def __post_init__(self): + if isinstance(self.dir_storage, dict): + self.dir_storage = ArchiveDirStorageConfig(**self.dir_storage) + if isinstance(self.acr, dict): + self.acr = ArchiveAcrConfig(**self.acr) + + +@dataclass +class SandboxLifecycleConfig: + auto_transition_interval_sec: int = 180 + reconcile_interval_sec: int = 30 + archive_timeout_sec: int = 1800 + restore_timeout_sec: int = 1800 + auto_archive_after_sec: int = 0 + auto_delete_after_sec: int = 0 + auto_clear_default_sec: int = 21600 + archive: ArchiveConfig = field(default_factory=ArchiveConfig) + + def __post_init__(self): + if isinstance(self.archive, dict): + self.archive = ArchiveConfig(**self.archive) + + @dataclass class ProxyServiceConfig: timeout: float = 180.0 @@ -333,6 +393,7 @@ class RockConfig: redis: RedisConfig = field(default_factory=RedisConfig) sandbox_config: SandboxConfig = field(default_factory=SandboxConfig) oss: OssConfig = field(default_factory=OssConfig) + lifecycle: SandboxLifecycleConfig = field(default_factory=SandboxLifecycleConfig) runtime: RuntimeConfig = field(default_factory=RuntimeConfig) proxy_service: ProxyServiceConfig = field(default_factory=ProxyServiceConfig) scheduler: SchedulerConfig = 
def _container_exists(self) -> bool:
    """Return True when ``docker inspect`` succeeds for ``self._container_name``.

    This is a blocking call (docker CLI, 5 second timeout); async callers
    should dispatch it to an executor. ``subprocess.TimeoutExpired``
    propagates if the CLI does not answer in time.
    """
    result = subprocess.run(
        ["docker", "inspect", self._container_name],
        capture_output=True,
        timeout=5,
    )
    return result.returncode == 0


async def restart_from_image(self, image_ref: str):
    """Restart a container, recreating it from a committed image if it no longer exists.

    Used by the archive-restore flow: the image was previously committed
    and pushed to a registry. If the original container still exists,
    behaves like restart(). If not, creates a new container from
    *image_ref* with the same config (ports, memory, cpus, volumes).

    Raises:
        Exception: if no host port for the rocklet proxy can be determined.
    """
    executor = get_executor()
    loop = asyncio.get_running_loop()

    # `docker inspect` can block for up to 5 s; keep it off the event loop like
    # the other docker CLI calls in this method.
    if await loop.run_in_executor(executor, self._container_exists):
        await self.restart()
        return

    logger.info(f"Container {self._container_name} not found, recreating from {image_ref}")

    # Recover port mappings from persisted file, or allocate new ones.
    status_path = PersistedServiceStatus.gen_service_status_path(self._container_name)
    self._service_status.set_sandbox_id(self._container_name)
    if os.path.exists(status_path):
        with open(status_path) as f:
            data = json.load(f)
        for port_value, mapping in data.get("port_mapping", {}).items():
            self._service_status.add_port_mapping(int(port_value), mapping)
        if self._config.port is None:
            self._config.port = self._service_status.port_mapping.get(Port.PROXY)

    if self._config.port is None:
        await self.do_port_mapping()

    # Resolve the proxy port *before* it is interpolated into the `-p` flag
    # below; otherwise a still-unset port would be published as "None:<port>".
    if self._config.port is None:
        self._config.port = self._service_status.port_mapping.get(Port.PROXY)
    if self._config.port is None:
        raise Exception(f"Cannot determine rocklet port for container {self._container_name}")

    # Build docker create command from archived image.
    env_arg = ["-e", f"ROCK_TIME_ZONE={env_vars.ROCK_TIME_ZONE}"]
    volume_args = self._prepare_volume_mounts()
    if env_vars.ROCK_LOGGING_PATH:
        log_file_path = f"{env_vars.ROCK_LOGGING_PATH}/{self._container_name}"
        os.makedirs(log_file_path, exist_ok=True)
        os.chmod(log_file_path, 0o777)
        volume_args.extend(["-v", f"{log_file_path}:{env_vars.ROCK_LOGGING_PATH}"])
        env_arg.extend(
            [
                "-e",
                f"ROCK_LOGGING_PATH={env_vars.ROCK_LOGGING_PATH}",
                "-e",
                f"ROCK_LOGGING_LEVEL={env_vars.ROCK_LOGGING_LEVEL}",
            ]
        )
    volume_args.extend(self._prepare_timezone_mount())
    runtime_args = self._build_runtime_args()

    cmds = [
        "docker",
        "create",
        "--entrypoint",
        "",
        *env_arg,
        *volume_args,
        *runtime_args,
        "-p",
        f"{self._config.port}:{Port.PROXY}",
        "-p",
        f"{self._service_status.get_mapped_port(Port.SERVER)}:8080",
        "-p",
        f"{self._service_status.get_mapped_port(Port.SSH)}:22",
        *self._memory(),
        *self._cpus(),
        *self._storage_opts(),
        *self._config.docker_args,
        "--name",
        self._container_name,
        image_ref,
        *self._get_rocklet_start_cmd(),
    ]

    cmd_str = shlex.join(cmds)
    logger.info(f"Recreating container {self._container_name} from archived image {image_ref}")
    logger.info(f"Command: {cmd_str!r}")

    await loop.run_in_executor(executor, self._docker_create, cmds)
    try:
        self._container_process = await loop.run_in_executor(executor, self._docker_start)
    except Exception:
        # Do not leave a created-but-never-started container behind.
        DockerUtil.remove_container_force(self._container_name)
        raise

    logger.info(f"Starting runtime at {self._config.port}")
    self._runtime = RemoteSandboxRuntime.from_config(
        RemoteSandboxRuntimeConfig(port=self._config.port, timeout=self._runtime_timeout)
    )
    self._runtime.set_executor(executor)

    with StageTimer("startup_timing", f"[{self._container_name}] Wait until alive", logger):
        await self._wait_until_alive(timeout=self._config.startup_timeout)

    if self._config.enable_auto_clear:
        self._check_stop_task = asyncio.create_task(self._check_stop())

    logger.info(f"Container {self._container_name} recreated from {image_ref} successfully")
def dir_archive_key(sandbox_id: str, prefix: str) -> str:
    """Return the object-store key of the directory archive for *sandbox_id*.

    *prefix* is prepended verbatim, so it must carry its own trailing
    separator (e.g. ``"rock-archives/"``).
    """
    return f"{prefix}{sandbox_id}.tar.gz"


def image_ref(sandbox_id: str, registry_url: str, namespace: str) -> str:
    """Return the image reference ``<registry>/<namespace>/sandbox_archived:<sandbox_id>``.

    A URL scheme on *registry_url* and stray leading/trailing slashes on either
    part are dropped, and empty parts are skipped: image references carry no
    scheme, and ``//`` inside the repository path is not a valid reference.
    Well-formed inputs produce exactly the same string as plain concatenation.
    """
    registry = registry_url.split("://", 1)[-1].strip("/")
    repo_namespace = namespace.strip("/")
    repository = "/".join(part for part in (registry, repo_namespace, "sandbox_archived") if part)
    return f"{repository}:{sandbox_id}"
'localhost:5000'.""" + + @abstractmethod + async def push_from_local(self, local_image_tag: str, remote_image_ref: str) -> None: + """Tag local image and push to registry.""" + + @abstractmethod + async def pull_to_local(self, remote_image_ref: str) -> None: + """Pull image from registry to local daemon. Raises RuntimeError if not found.""" + + @abstractmethod + async def delete(self, image_ref: str) -> bool: + """Delete manifest from registry via V2 HTTP API. Returns False if 404.""" + + @abstractmethod + async def exists(self, image_ref: str) -> bool: + """HEAD manifest via V2 HTTP API.""" diff --git a/rock/sandbox/archive/oss_storage.py b/rock/sandbox/archive/oss_storage.py new file mode 100644 index 0000000000..d3c2de0a1f --- /dev/null +++ b/rock/sandbox/archive/oss_storage.py @@ -0,0 +1,128 @@ +import asyncio +import os +import shutil +import tempfile + +import oss2 + +from rock.sandbox.archive.abstract import AbstractDirStorage + + +class OssDirStorage(AbstractDirStorage): + """OSS Access Point directory storage using oss2 SDK + AuthV4.""" + + def __init__( + self, + endpoint: str, + bucket: str, + access_key_id: str, + access_key_secret: str, + region: str = "cn-hangzhou", + ): + self._endpoint = endpoint + self._bucket = bucket + self._access_key_id = access_key_id + self._access_key_secret = access_key_secret + self._region = region + + @property + def client_config(self) -> dict: + return { + "type": "oss", + "endpoint": self._endpoint, + "bucket": self._bucket, + "access_key_id": self._access_key_id, + "access_key_secret": self._access_key_secret, + "region": self._region, + } + + def _make_bucket(self) -> oss2.Bucket: + auth = oss2.AuthV4(self._access_key_id, self._access_key_secret) + return oss2.Bucket(auth, self._endpoint, self._bucket, region=self._region) + + async def upload_dir(self, local_dir: str, key: str) -> None: + if not os.path.isdir(local_dir): + raise FileNotFoundError(f"Directory not found: {local_dir}") + + parent = 
async def download_to_dir(self, key: str, local_dir: str) -> None:
    """Download the archive stored under *key* and restore it as *local_dir*.

    The tarball is unpacked into a private staging directory created next to
    the target and the expected top-level directory is then renamed into
    place. Compared with extracting straight into the parent directory this
    keeps unexpected archive members away from sibling directories, never
    exposes a half-extracted *local_dir*, and leaves nothing behind on failure.

    Raises:
        FileExistsError: *local_dir* already exists.
        FileNotFoundError: *key* does not exist in the bucket.
        RuntimeError: ``tar`` failed, or the archive does not contain a
            top-level directory named like the basename of *local_dir*.
    """
    if os.path.exists(local_dir):
        raise FileExistsError(f"Target directory already exists: {local_dir}")

    target = os.path.abspath(local_dir)
    parent = os.path.dirname(target)
    basename = os.path.basename(target)
    os.makedirs(parent, exist_ok=True)

    tmp_tar = None
    staging = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as f:
            tmp_tar = f.name

        bucket = self._make_bucket()
        try:
            await asyncio.to_thread(bucket.get_object_to_file, key, tmp_tar)
        except oss2.exceptions.NoSuchKey as exc:
            raise FileNotFoundError(f"Key not found: {key}") from exc

        # Staging lives inside `parent` so the final rename stays on one filesystem.
        staging = tempfile.mkdtemp(prefix=f".{basename}.", suffix=".restore", dir=parent)
        proc = await asyncio.create_subprocess_exec(
            "tar",
            "-xzf",
            tmp_tar,
            "-C",
            staging,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"tar extract failed: {stderr.decode(errors='replace')}")

        extracted = os.path.join(staging, basename)
        if not os.path.isdir(extracted):
            entries = os.listdir(staging)
            raise RuntimeError(f"Expected directory '{basename}' after extraction, got: {entries}")
        os.rename(extracted, target)
    finally:
        # Only remove what this call created; never touch a pre-existing target.
        if staging is not None:
            shutil.rmtree(staging, ignore_errors=True)
        if tmp_tar and os.path.exists(tmp_tar):
            os.unlink(tmp_tar)
+ registry, name, tag = self._parse_ref(image_ref) + base_url = await self._registry_base_url(registry) + accept = "application/vnd.docker.distribution.manifest.v2+json, application/vnd.oci.image.manifest.v1+json" + async with httpx.AsyncClient() as client: + auth_headers = await self._resolve_auth_headers(client, base_url, name) + if not auth_headers: + return False + + resp = await client.get( + f"{base_url}/v2/{name}/manifests/{tag}", + headers={**auth_headers, "Accept": accept}, + ) + if resp.status_code != 200: + return False + digest = resp.headers.get("Docker-Content-Digest") + if not digest: + return False + + resp = await client.delete(f"{base_url}/v2/{name}/manifests/{digest}", headers=auth_headers) + return resp.status_code in (200, 202) + + async def _resolve_auth_headers(self, client: httpx.AsyncClient, base_url: str, repo_name: str) -> dict | None: + """Negotiate auth with the registry and return headers for subsequent requests.""" + challenge = await client.get(f"{base_url}/v2/") + www_auth = challenge.headers.get("Www-Authenticate", "") + + if www_auth.lower().startswith("basic"): + creds = base64.b64encode(f"{self._username}:{self._password}".encode()).decode() + return {"Authorization": f"Basic {creds}"} + + m = re.search(r'realm="([^"]+)"', www_auth) + realm = m.group(1) if m else None + m = re.search(r'service="([^"]+)"', www_auth) + service = m.group(1) if m else None + if not realm or not service: + return None + + token_resp = await client.get( + realm, + params={"service": service, "scope": f"repository:{repo_name}:pull,push,delete"}, + auth=(self._username, self._password), + ) + if token_resp.status_code != 200: + return None + token = token_resp.json().get("token") + if not token: + return None + return {"Authorization": f"Bearer {token}"} + + async def exists(self, image_ref: str) -> bool: + async with self._docker_auth() as env: + try: + await self._docker_cmd("docker", "manifest", "inspect", "--insecure", image_ref, env=env) + return 
class _DockerAuthContext:
    """Async context manager that creates an isolated DOCKER_CONFIG and performs docker login.

    Entering yields a copy of the process environment whose ``DOCKER_CONFIG``
    points at a fresh temporary directory, after ``docker login`` has stored
    the registry credentials there. The directory is removed on exit, and also
    when the login itself fails (``__aexit__`` is not invoked if ``__aenter__``
    raises, so that path has to clean up on its own).
    """

    def __init__(self, storage: "DockerRegistryV2ImageStorage"):
        self._storage = storage
        self._tmpdir = None

    async def __aenter__(self) -> dict:
        self._tmpdir = tempfile.mkdtemp()
        env = os.environ.copy()
        env["DOCKER_CONFIG"] = self._tmpdir
        try:
            # Password goes through stdin so it never shows up in the process list.
            await self._storage._docker_cmd(
                "docker",
                "login",
                self._storage._registry_url,
                "--username",
                self._storage._username,
                "--password-stdin",
                env=env,
                stdin_data=self._storage._password.encode(),
            )
        except BaseException:
            # Includes cancellation: never leave the config directory behind.
            self._cleanup()
            raise
        return env

    async def __aexit__(self, *args):
        self._cleanup()

    def _cleanup(self) -> None:
        """Remove the temporary DOCKER_CONFIG directory, if one was created."""
        if self._tmpdir:
            shutil.rmtree(self._tmpdir, ignore_errors=True)
            self._tmpdir = None
async def download_to_dir(self, key: str, local_dir: str) -> None:
    """Download the archive stored under *key* and restore it as *local_dir*.

    The tarball is unpacked into a private staging directory created next to
    the target and the expected top-level directory is then renamed into
    place. Compared with extracting straight into the parent directory this
    keeps unexpected archive members away from sibling directories, never
    exposes a half-extracted *local_dir*, and leaves nothing behind on failure.

    Raises:
        FileExistsError: *local_dir* already exists.
        FileNotFoundError: *key* does not exist in the bucket.
        RuntimeError: ``tar`` failed, or the archive does not contain a
            top-level directory named like the basename of *local_dir*.
    """
    if os.path.exists(local_dir):
        raise FileExistsError(f"Target directory already exists: {local_dir}")

    target = os.path.abspath(local_dir)
    parent = os.path.dirname(target)
    basename = os.path.basename(target)
    os.makedirs(parent, exist_ok=True)

    tmp_tar = None
    staging = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as f:
            tmp_tar = f.name

        client = self._make_client()
        try:
            await asyncio.to_thread(client.download_file, self._bucket, key, tmp_tar)
        except ClientError as e:
            # Missing objects surface as "404" (HEAD) or "NoSuchKey"/"NotFound"
            # depending on which request of the transfer hit the miss.
            if e.response["Error"]["Code"] in ("404", "NoSuchKey", "NotFound"):
                raise FileNotFoundError(f"Key not found: {key}") from e
            raise

        # Staging lives inside `parent` so the final rename stays on one filesystem.
        staging = tempfile.mkdtemp(prefix=f".{basename}.", suffix=".restore", dir=parent)
        proc = await asyncio.create_subprocess_exec(
            "tar",
            "-xzf",
            tmp_tar,
            "-C",
            staging,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        _, stderr = await proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"tar extract failed: {stderr.decode(errors='replace')}")

        extracted = os.path.join(staging, basename)
        if not os.path.isdir(extracted):
            entries = os.listdir(staging)
            raise RuntimeError(f"Expected directory '{basename}' after extraction, got: {entries}")
        os.rename(extracted, target)
    finally:
        # Only remove what this call created; never touch a pre-existing target.
        if staging is not None:
            shutil.rmtree(staging, ignore_errors=True)
        if tmp_tar and os.path.exists(tmp_tar):
            os.unlink(tmp_tar)
await self.exists(key): + return False + client = self._make_client() + await asyncio.to_thread(client.delete_object, Bucket=self._bucket, Key=key) + return True + + async def exists(self, key: str) -> bool: + client = self._make_client() + try: + await asyncio.to_thread(client.head_object, Bucket=self._bucket, Key=key) + return True + except ClientError as e: + if e.response["Error"]["Code"] == "404": + return False + raise diff --git a/rock/sandbox/base_manager.py b/rock/sandbox/base_manager.py index c89e731aef..522410e323 100644 --- a/rock/sandbox/base_manager.py +++ b/rock/sandbox/base_manager.py @@ -35,7 +35,8 @@ def __init__( user_defined_tags=rock_config.runtime.user_defined_tags, ) self._report_interval = 10 - self._check_job_interval = 180 + self._check_job_interval = rock_config.lifecycle.auto_transition_interval_sec + self._reconcile_interval = rock_config.lifecycle.reconcile_interval_sec self._setup_scheduler() self.deployment_manager = DeploymentManager(rock_config, enable_runtime_auto_clear) @@ -61,18 +62,23 @@ def _setup_metrics_scheduler(self): logger.info("APScheduler started for metrics collection") def _setup_job_check_scheduler(self): - """Set up scheduler""" self.scheduler = AsyncIOScheduler( timezone="UTC", job_defaults={"coalesce": True, "max_instances": 1, "misfire_grace_time": 30} ) self.scheduler.add_job( - func=self._check_job_background, + func=self._auto_transition, trigger=IntervalTrigger(seconds=self._check_job_interval), - id="job_check", - name="Sandbox Job Check", + id="auto_transition", + name="Sandbox Auto Transition", + ) + self.scheduler.add_job( + func=self._reconcile, + trigger=IntervalTrigger(seconds=self._reconcile_interval), + id="reconcile", + name="Sandbox Reconcile", ) self.scheduler.start() - logger.info("APScheduler started for job check") + logger.info("APScheduler started for auto_transition and reconcile") async def _collect_and_report_metrics(self): start_time = time.time() diff --git 
a/rock/sandbox/operator/abstract.py b/rock/sandbox/operator/abstract.py index b54afd77cb..0440a98964 100644 --- a/rock/sandbox/operator/abstract.py +++ b/rock/sandbox/operator/abstract.py @@ -40,6 +40,33 @@ async def stop(self, sandbox_id: str, reason: StopReason = StopReason.MANUAL) -> async def delete(self, config: DeploymentConfig, host_ip: str | None = None) -> bool: ... + def supports_archive(self) -> bool: + return False + + async def start_archive( + self, + config: DeploymentConfig, + host_ip: str | None, + dir_storage_config: dict, + image_storage_config: dict, + archive_params: dict | None = None, + ) -> None: + from rock.sdk.common.exceptions import BadRequestRockError + + raise BadRequestRockError(f"archive not supported on {type(self).__name__}") + + async def start_restore( + self, + config: DeploymentConfig, + host_ip: str | None, + dir_storage_config: dict, + image_storage_config: dict, + archive_params: dict | None = None, + ) -> None: + from rock.sdk.common.exceptions import BadRequestRockError + + raise BadRequestRockError(f"restore not supported on {type(self).__name__}") + def set_redis_provider(self, redis_provider: RedisProvider): self._redis_provider = redis_provider diff --git a/rock/sandbox/operator/ray.py b/rock/sandbox/operator/ray.py index 0f03f31d72..e73f1ca504 100644 --- a/rock/sandbox/operator/ray.py +++ b/rock/sandbox/operator/ray.py @@ -1,3 +1,4 @@ +import asyncio import json import ray @@ -32,23 +33,26 @@ def __init__(self, ray_service: RayService, runtime_config: RuntimeConfig): def _get_actor_name(self, sandbox_id: str) -> str: return f"sandbox-{sandbox_id}" - async def create_actor(self, config: DockerDeploymentConfig, pin_to_host_ip: str | None = None): + async def create_actor( + self, + config: DockerDeploymentConfig, + pin_to_host_ip: str | None = None, + ): actor_options = self._generate_actor_options(config, pin_to_host_ip=pin_to_host_ip) deployment: DockerDeployment = config.get_deployment() sandbox_actor = 
async def start_archive(
    self,
    config: DockerDeploymentConfig,
    host_ip: str | None,
    dir_storage_config: dict,
    image_storage_config: dict,
    archive_params: dict | None = None,
) -> None:
    """Spawn a small helper actor and run the archive job in the background.

    Returns as soon as the actor is created; the archive itself runs in a
    background task (``_run_archive_and_kill``).

    Note: *config* is modified in place (``cpus``/``memory`` are shrunk) so the
    helper actor only reserves a small resource footprint.
    """
    async with self._ray_service.get_ray_rwlock().read_lock():
        config.cpus = 0.1
        config.memory = "256m"
        sandbox_actor = await self.create_actor(config, pin_to_host_ip=host_ip)
        task = asyncio.create_task(
            self._run_archive_and_kill(sandbox_actor, dir_storage_config, image_storage_config, archive_params)
        )

    # The event loop keeps only a weak reference to tasks, so an unreferenced
    # fire-and-forget task may be garbage-collected before it finishes. Hold a
    # strong reference until the task is done.
    background_tasks = getattr(self, "_background_tasks", None)
    if background_tasks is None:
        background_tasks = self._background_tasks = set()
    background_tasks.add(task)
    task.add_done_callback(background_tasks.discard)
+ except Exception as e: + logger.exception(f"archive remote failed: {e}") + finally: + ray.kill(actor) + + async def start_restore( + self, + config: DockerDeploymentConfig, + host_ip: str | None, + dir_storage_config: dict, + image_storage_config: dict, + archive_params: dict | None = None, + ) -> None: + """Fire-and-forget restore: actor pulls image, downloads logs, starts container. + + The actor stays alive as the long-lived detached actor for the sandbox + (same as what `submit` creates for a new sandbox). `get_status` alive + detection will transition the state to RUNNING once the container is up. + """ + async with self._ray_service.get_ray_rwlock().read_lock(): + sandbox_actor = await self.create_actor(config, pin_to_host_ip=host_ip) + sandbox_actor.restore_and_start.remote(dir_storage_config, image_storage_config, archive_params) diff --git a/rock/sandbox/sandbox_actor.py b/rock/sandbox/sandbox_actor.py index f34aea31ed..9313723315 100644 --- a/rock/sandbox/sandbox_actor.py +++ b/rock/sandbox/sandbox_actor.py @@ -1,6 +1,7 @@ import asyncio import datetime import os +import shutil import socket import subprocess import tempfile @@ -8,6 +9,7 @@ import ray from fastapi import UploadFile +from rock import env_vars from rock.actions import ( BashObservation, CloseBashSessionResponse, @@ -32,11 +34,24 @@ from rock.deployments.docker import DockerDeployment from rock.deployments.status import ServiceStatus from rock.logger import init_logger +from rock.sandbox.archive._constants import dir_archive_key, image_ref +from rock.sandbox.archive.oss_storage import OssDirStorage +from rock.sandbox.archive.registry_v2 import DockerRegistryV2ImageStorage +from rock.sandbox.archive.s3_storage import S3DirStorage from rock.sandbox.gem_actor import GemActor +from rock.utils.format import parse_size_to_bytes logger = init_logger(__name__) +def _make_dir_storage(config: dict): + config = dict(config) + storage_type = config.pop("type", "oss") + if storage_type == "s3": + 
async def archive(
    self,
    dir_storage_config: dict,
    image_storage_config: dict,
    archive_params: dict | None = None,
) -> None:
    """Archive the sandbox: commit + push its image, then tar + upload its log dir.

    Args:
        dir_storage_config: Keyword config for the directory storage backend
            (its optional ``type`` entry selects the backend).
        image_storage_config: Keyword config for the registry image storage.
        archive_params: Optional overrides: ``archive_prefix``,
            ``acr_namespace``, ``max_image_push_size``, ``max_dir_upload_size``.

    Raises:
        RuntimeError: The committed image or the log directory exceeds its
            configured size limit.

    The pushed image is deleted again whenever the log phase fails, so a
    failed archive never leaves a remote image behind.
    """
    sandbox_id = self._config.container_name
    dir_storage = _make_dir_storage(dir_storage_config)
    image_storage = DockerRegistryV2ImageStorage(**image_storage_config)
    archive_params = archive_params or {}
    prefix = archive_params.get("archive_prefix", "rock-archives/")
    acr_ns = archive_params.get("acr_namespace", "sandbox_archive")

    local_tag = f"archive-staging-{sandbox_id}:latest"
    ref = image_ref(sandbox_id, image_storage.registry_url, acr_ns)

    await self._run_shell_command("docker", "commit", sandbox_id, local_tag)
    try:
        max_image = archive_params.get("max_image_push_size", "")
        if max_image:
            image_size = await self._get_image_size(local_tag)
            max_bytes = parse_size_to_bytes(max_image)
            if image_size > max_bytes:
                raise RuntimeError(
                    f"[{sandbox_id}] image size {image_size} bytes exceeds limit {max_image} ({max_bytes} bytes)"
                )
        await image_storage.push_from_local(local_tag, ref)
    finally:
        # The staging tag is only a local scratch image: drop it on success,
        # on push failure, and also when the size probe itself fails.
        await self._run_shell_command("docker", "rmi", local_tag, check=False)

    log_root = env_vars.ROCK_LOGGING_PATH
    if not log_root:
        logger.info(f"[{sandbox_id}] ROCK_LOGGING_PATH not set, skipping log dir archive")
        return

    log_dir = f"{log_root}/{sandbox_id}"
    key = dir_archive_key(sandbox_id, prefix)

    if not os.path.isdir(log_dir):
        logger.info(f"[{sandbox_id}] log dir {log_dir} not found, skipping log archive")
        return

    try:
        max_dir = archive_params.get("max_dir_upload_size", "")
        if max_dir:
            dir_size = await self._get_dir_size(log_dir)
            max_bytes = parse_size_to_bytes(max_dir)
            if dir_size > max_bytes:
                raise RuntimeError(
                    f"[{sandbox_id}] log dir size {dir_size} bytes exceeds limit {max_dir} ({max_bytes} bytes)"
                )
        await dir_storage.upload_dir(log_dir, key)
    except Exception:
        # Roll back the pushed image for *every* log-phase failure (size limit
        # included), not only upload errors: a remote image that outlives a
        # failed archive would make the archive look complete.
        try:
            await image_storage.delete(ref)
        except Exception as cleanup_err:
            # Best effort: never let the rollback hide the original failure.
            logger.warning(f"[{sandbox_id}] failed to delete archive image {ref} during rollback: {cleanup_err}")
        raise
    shutil.rmtree(log_dir, ignore_errors=True)
+ shutil.rmtree(target_dir) + await dir_storage.download_to_dir(key, target_dir) + except Exception as e: + logger.warning(f"[{sandbox_id}] log restore failed, continuing without logs: {e}") + else: + logger.warning(f"[{sandbox_id}] archived log key {key} not found in OSS, skipping log restore") + else: + logger.info(f"[{sandbox_id}] ROCK_LOGGING_PATH not set, skipping log restore") + + if isinstance(self._deployment, DockerDeployment): + await self._deployment.restart_from_image(ref) + else: + await self._deployment.restart() + if isinstance(self._deployment, DockerDeployment): + self._clean_container_background() + await self._setup_monitor() + logger.info(f"[{sandbox_id}] restore_and_start complete") diff --git a/rock/sandbox/sandbox_manager.py b/rock/sandbox/sandbox_manager.py index 1c1ff3070a..b124de337a 100644 --- a/rock/sandbox/sandbox_manager.py +++ b/rock/sandbox/sandbox_manager.py @@ -1,5 +1,7 @@ import asyncio +import datetime import time +from datetime import timezone from fastapi import UploadFile @@ -31,6 +33,7 @@ from rock.logger import init_logger from rock.rocklet import __version__ as swe_version from rock.sandbox import __version__ as gateway_version +from rock.sandbox.archive._constants import image_ref from rock.sandbox.base_manager import BaseManager from rock.sandbox.operator.abstract import AbstractOperator from rock.sandbox.sandbox_actor import SandboxActor @@ -67,6 +70,8 @@ def __init__( self._ray_service = ray_service self._ray_namespace = ray_namespace self._operator = operator + self._dir_storage = None + self._image_storage = None self._aes_encrypter = AESEncryption() self._proxy_service = SandboxProxyService(rock_config=rock_config, meta_store=meta_store) logger.info("sandbox service init success") @@ -158,16 +163,18 @@ async def restart_async(self, sandbox_id: str) -> SandboxStartResponse: raise BadRequestRockError(f"Sandbox {sandbox_id} not found") state = sm.current_state.value - if state != State.STOPPED: + if state == 
State.ARCHIVED: + await self.restart_from_archived(sandbox_id) + elif state == State.STOPPED: + await sm.send( + "restart", + sandbox_id=sandbox_id, + operator=self._operator, + meta_store=self._meta_store, + ) + else: raise BadRequestRockError(f"Sandbox {sandbox_id} cannot be restarted: current state is '{state.value}'") - await sm.send( - "restart", - sandbox_id=sandbox_id, - operator=self._operator, - meta_store=self._meta_store, - ) - info: SandboxInfo = sm.sandbox_info or {} return SandboxStartResponse( sandbox_id=sandbox_id, @@ -233,16 +240,19 @@ async def delete(self, sandbox_id: str, reason: DeleteReason = DeleteReason.MANU if state == State.DELETED: logger.info(f"delete: sandbox {sandbox_id} already deleted, noop") return - if state != State.STOPPED: + if state not in (State.STOPPED, State.ARCHIVED): raise BadRequestRockError( - f"Sandbox {sandbox_id} cannot be deleted: current state is '{state.value}', must be stopped first" + f"Sandbox {sandbox_id} cannot be deleted: current state is '{state.value}', must be stopped or archived first" ) + await sm.send( "delete", sandbox_id=sandbox_id, operator=self._operator, meta_store=self._meta_store, reason=reason, + dir_storage=self._dir_storage, + image_storage=self._image_storage, ) async def get_mount(self, sandbox_id): @@ -270,31 +280,35 @@ async def commit(self, sandbox_id, image_tag: str, username: str, password: str) logger.info(f"commit {sandbox_id} to {image_tag} finished, result {result}") return result + async def _try_advance_pending(self, sandbox_id: str, sm) -> dict | None: + """Probe operator alive; fire ``alive`` transition if RUNNING. 
Returns operator info or None.""" + operator_sandbox_info = await self._operator.get_status(sandbox_id=sandbox_id) + if operator_sandbox_info is None: + return None + is_alive = operator_sandbox_info.get("state") == State.RUNNING + if sm.current_state.value == State.PENDING and is_alive: + await sm.send( + "alive", sandbox_id=sandbox_id, meta_store=self._meta_store, sandbox_info=operator_sandbox_info + ) + if operator_sandbox_info.get("state") in (State.PENDING, State.RUNNING): + await self._refresh_timeout(sandbox_id) + return operator_sandbox_info + @monitor_sandbox_operation() async def get_status(self, sandbox_id, include_all_states: bool = False) -> SandboxStatusResponse: - # get status from meta_store sm = await self._get_current_statemachine(sandbox_id) if sm is None: raise BadRequestRockError(f"Sandbox {sandbox_id} not found") - # update status from operator - is_alive = False - operator_sandbox_info: SandboxInfo | None = await self._operator.get_status(sandbox_id=sandbox_id) - if operator_sandbox_info is not None: - is_alive = operator_sandbox_info.get("state") == State.RUNNING - if sm.current_state.value == State.PENDING and is_alive: - await sm.send( - "alive", sandbox_id=sandbox_id, meta_store=self._meta_store, sandbox_info=operator_sandbox_info - ) - if operator_sandbox_info.get("state") in (State.PENDING, State.RUNNING): - await self._refresh_timeout(sandbox_id) + operator_sandbox_info = await self._try_advance_pending(sandbox_id, sm) + is_alive = operator_sandbox_info is not None and operator_sandbox_info.get("state") == State.RUNNING # compat with legacy get_status behavior by default (include_all_states == False), # raise 'not found' if not on pending or running status. 
async def _auto_transition(self):
    """Long-interval scan: expire RUNNING/PENDING → STOPPED, auto-delete/archive stale STOPPED.

    Every alive sandbox that ``_is_expired`` reports as expired is stopped in a
    background task; afterwards stale STOPPED sandboxes are auto-deleted and the
    remaining ones auto-archived. A failure on one sandbox is logged and does
    not abort the scan; cancellation of the scan itself is propagated.
    """
    logger.debug("auto_transition")

    # The event loop only keeps weak references to tasks; hold strong ones so
    # an in-flight stop cannot be garbage-collected before it completes.
    stop_tasks = getattr(self, "_auto_stop_tasks", None)
    if stop_tasks is None:
        stop_tasks = self._auto_stop_tasks = set()

    async for sandbox_id in self._meta_store.iter_alive_sandbox_ids():
        try:
            if await self._is_expired(sandbox_id):
                logger.info(f"[auto_transition] {sandbox_id} expired, stopping")
                task = asyncio.create_task(self.stop(sandbox_id, reason=StopReason.EXPIRED))
                stop_tasks.add(task)
                task.add_done_callback(stop_tasks.discard)
        except asyncio.CancelledError:
            # Swallowing cancellation would keep the job running after the
            # scheduler asked it to stop; always let it propagate.
            raise
        except Exception as e:
            logger.error(f"[auto_transition] {sandbox_id}: {e}", exc_info=True)

    deleted_ids = await self._auto_delete_stopped()
    await self._auto_archive_stopped(exclude_ids=deleted_ids)
+ """ + auto_delete_sec = self.rock_config.lifecycle.auto_delete_after_sec + if not auto_delete_sec: + return set() + + try: + stopped_list = await self._meta_store.list_by("state", State.STOPPED.value) + except Exception as e: + logger.warning(f"[auto_delete] list_by failed: {e}") + return set() + + deleted: set[str] = set() + now = datetime.datetime.now(timezone.utc) + for info in stopped_list: + sandbox_id = info.get("sandbox_id", "") + stop_time_str = info.get("stop_time", "") + if not sandbox_id or not stop_time_str: + continue + try: + stop_time = datetime.datetime.fromisoformat(stop_time_str.replace("Z", "+00:00")) + elapsed = (now - stop_time).total_seconds() + except (ValueError, TypeError): + continue + if elapsed < auto_delete_sec: continue + try: + logger.info(f"[auto_delete] {sandbox_id} stopped for {int(elapsed)}s, deleting") + await self.delete(sandbox_id, reason=DeleteReason.EXPIRED) + deleted.add(sandbox_id) except Exception as e: - logger.error("check_job_background Exception", exc_info=e) + logger.error(f"[auto_delete] {sandbox_id}: {e}", exc_info=True) + return deleted + + async def _auto_archive_stopped(self, *, exclude_ids: set[str] | None = None) -> None: + """Archive STOPPED sandboxes that have been idle longer than auto_archive_after_sec.""" + auto_archive_sec = self.rock_config.lifecycle.auto_archive_after_sec + if not auto_archive_sec: + return + if not self._operator or not self._operator.supports_archive(): + return + if not self._dir_storage or not self._image_storage: + return + + try: + stopped_list = await self._meta_store.list_by("state", State.STOPPED.value) + except Exception as e: + logger.warning(f"[auto_archive] list_by failed: {e}") + return + + now = datetime.datetime.now(timezone.utc) + for info in stopped_list: + sandbox_id = info.get("sandbox_id", "") + stop_time_str = info.get("stop_time", "") + if not sandbox_id or not stop_time_str: continue + if exclude_ids and sandbox_id in exclude_ids: + continue + try: + stop_time 
= datetime.datetime.fromisoformat(stop_time_str.replace("Z", "+00:00")) + elapsed = (now - stop_time).total_seconds() + except (ValueError, TypeError): + continue + if elapsed < auto_archive_sec: + continue + try: + logger.info(f"[auto_archive] {sandbox_id} stopped for {int(elapsed)}s, archiving") + await self.archive_sandbox(sandbox_id) + except Exception as e: + logger.error(f"[auto_archive] {sandbox_id}: {e}", exc_info=True) + + async def _reconcile(self) -> None: + """Reconcile intermediate states (PENDING, ARCHIVING) on short interval.""" + await self._reconcile_pending() + await self._reconcile_archiving() async def get_sandbox_statistics(self, sandbox_id): actor_name = self.deployment_manager.get_actor_name(sandbox_id) @@ -436,6 +532,172 @@ def validate_sandbox_spec(self, runtime_config: RuntimeConfig, deployment_config parse_size_to_bytes(deployment_config.disk_limit_rootfs) except ValueError as e: logger.warning(f"Invalid disk_limit_rootfs size: {deployment_config.disk_limit_rootfs}", exc_info=e) - raise BadRequestRockError( - f"Invalid disk_limit_rootfs size: {deployment_config.disk_limit_rootfs}" - ) + raise BadRequestRockError(f"Invalid disk_limit_rootfs size: {deployment_config.disk_limit_rootfs}") + + async def archive_sandbox(self, sandbox_id: str) -> None: + """Validate preconditions, then fire archive transition (cleanup + actor dispatch in on_archive).""" + if not self._operator or not self._operator.supports_archive(): + raise BadRequestRockError(f"archive not supported on {type(self._operator).__name__}") + if not self._dir_storage or not self._image_storage: + raise BadRequestRockError("archive not configured: missing storage credentials") + + sm = await self._get_current_statemachine(sandbox_id) + if sm is None: + raise BadRequestRockError(f"sandbox {sandbox_id} not found") + + archive_cfg = self.rock_config.lifecycle.archive + archive_params = { + "archive_prefix": archive_cfg.prefix, + "acr_namespace": archive_cfg.acr.namespace, + 
"max_image_push_size": archive_cfg.max_image_push_size, + "max_dir_upload_size": archive_cfg.max_dir_upload_size, + } + await sm.send( + "archive", + sandbox_id=sandbox_id, + meta_store=self._meta_store, + operator=self._operator, + dir_storage=self._dir_storage, + image_storage=self._image_storage, + archive_params=archive_params, + ) + + async def restart_from_archived(self, sandbox_id: str) -> None: + """Async restart from ARCHIVED: transition to PENDING, fire-and-forget actor. + + The actor handles pull + download + docker start. get_status alive detection + drives the PENDING → RUNNING transition. + """ + if not self._dir_storage or not self._image_storage: + raise BadRequestRockError("archive not configured: missing storage credentials") + + sm = await self._get_current_statemachine(sandbox_id) + if sm is None: + raise BadRequestRockError(f"sandbox {sandbox_id} not found") + + if sm.current_state.value != State.ARCHIVED: + return + + info = sm.sandbox_info or {} + spec = info.get("spec") or {} + if not spec: + raise BadRequestRockError(f"sandbox {sandbox_id} has no spec snapshot; cannot restore") + + timeout_info = SandboxTimeoutHelper.make_timeout_info(DockerDeploymentConfig(**spec).auto_clear_time) + + await sm.send( + "restore", + sandbox_id=sandbox_id, + meta_store=self._meta_store, + timeout_info=timeout_info, + operator=self._operator, + dir_storage=self._dir_storage, + image_storage=self._image_storage, + ) + + async def _reconcile_archiving(self) -> None: + """Reconcile ARCHIVING sandboxes: verify completion or roll back to STOPPED on timeout.""" + if not self._operator or not self._operator.supports_archive(): + return + if not self._dir_storage or not self._image_storage: + return + + try: + archiving = await self._meta_store.list_by("state", State.ARCHIVING.value) + except Exception as e: + logger.warning(f"[reconcile_archiving] list_by failed: {e}") + return + + for info in archiving: + sandbox_id = info.get("sandbox_id", "") + if not sandbox_id: 
async def _reconcile_pending(self) -> None:
    """Reconcile PENDING sandboxes: advance to RUNNING or timeout restore back to ARCHIVED.

    A PENDING record that carries ``archive_time`` is treated as a restore in
    progress: once ``state_enter_time`` is older than
    ``lifecycle.restore_timeout_sec`` the ``restore_failed`` transition is
    fired. Every other PENDING sandbox is probed via ``_try_advance_pending``.
    A failure on one sandbox is logged and does not abort the pass;
    cancellation of the pass itself is propagated.
    """
    try:
        pending_list = await self._meta_store.list_by("state", State.PENDING.value)
    except Exception as e:
        logger.warning(f"[reconcile_pending] list_by failed: {e}")
        return

    for info in pending_list:
        sandbox_id = info.get("sandbox_id", "")
        if not sandbox_id:
            continue
        try:
            # 1. Restore timeout (archive_time present = restoring from ARCHIVED)
            started_at = info.get("state_enter_time", "")
            if info.get("archive_time") and started_at:
                try:
                    started = datetime.datetime.fromisoformat(started_at.replace("Z", "+00:00"))
                    elapsed = (datetime.datetime.now(timezone.utc) - started).total_seconds()
                except (ValueError, TypeError):
                    # Unparseable (or timezone-naive) timestamp: treat as not timed out.
                    elapsed = 0
                if elapsed >= self.rock_config.lifecycle.restore_timeout_sec:
                    sm = await self._get_current_statemachine(sandbox_id)
                    if sm and sm.current_state.value == State.PENDING:
                        await sm.send(
                            "restore_failed",
                            sandbox_id=sandbox_id,
                            meta_store=self._meta_store,
                            reason=f"timeout after {int(elapsed)}s",
                        )
                        logger.warning(f"[reconcile_pending] restore_failed: {sandbox_id} ({int(elapsed)}s)")
                    continue

            # 2. Try advance PENDING → RUNNING
            sm = await self._get_current_statemachine(sandbox_id)
            if sm is None:
                continue
            await self._try_advance_pending(sandbox_id, sm)
        except asyncio.CancelledError:
            # Swallowing cancellation would keep the job running after the
            # scheduler asked it to stop; always let it propagate.
            raise
        except Exception as e:
            logger.error(f"[reconcile_pending] {sandbox_id}: {e}", exc_info=True)
rock.logger import init_logger +from rock.sandbox.archive._constants import dir_archive_key, image_ref from rock.sandbox.utils.timeout import SandboxTimeoutHelper from rock.sdk.common.exceptions import BadRequestRockError from rock.utils.system import get_iso8601_timestamp @@ -46,6 +47,8 @@ class SandboxStateMachine(StateChart): pending = SMState("Pending", initial=True, value=RockState.PENDING) running = SMState("Running", value=RockState.RUNNING) stopped = SMState("Stopped", value=RockState.STOPPED) + archiving = SMState("Archiving", value=RockState.ARCHIVING) + archived = SMState("Archived", value=RockState.ARCHIVED) deleted = SMState("Deleted", final=True, value=RockState.DELETED) # Transitions @@ -53,7 +56,12 @@ class SandboxStateMachine(StateChart): stop_noop = stopped.to(stopped) alive = pending.to(running) restart = stopped.to(pending) - delete = stopped.to(deleted) + delete = stopped.to(deleted) | archived.to(deleted) + archive = stopped.to(archiving) + archive_done = archiving.to(archived) + archive_failed = archiving.to(stopped) + restore = archived.to(pending) + restore_failed = pending.to(archived) def __init__(self, **kwargs): """Initialize with optional sandbox_info.""" @@ -145,6 +153,8 @@ async def on_delete( operator, meta_store, reason: DeleteReason = DeleteReason.MANUAL, + dir_storage=None, + image_storage=None, ) -> None: logger.info(f"delete sandbox {sandbox_id} (reason={reason.value})") sandbox_info = self.sandbox_info or {} @@ -168,11 +178,124 @@ async def on_delete( "rely on ContainerCleanupTask to reap docker container" ) + if sandbox_info.get("archive_time") and dir_storage and image_storage: + prefix = sandbox_info.get("archive_prefix", "rock-archives/") + acr_ns = sandbox_info.get("acr_namespace", "sandbox_archive") + key = dir_archive_key(sandbox_id, prefix) + ref = image_ref(sandbox_id, image_storage.registry_url, acr_ns) + try: + await dir_storage.delete(key) + except Exception as e: + logger.warning(f"delete: cleanup archive dir 
{key} failed: {e}") + try: + await image_storage.delete(ref) + except Exception as e: + logger.warning(f"delete: cleanup archive image {ref} failed: {e}") + sandbox_info["state"] = RockState.DELETED sandbox_info["delete_time"] = get_iso8601_timestamp() await meta_store.archive(sandbox_id, sandbox_info) self.sandbox_info = sandbox_info + async def on_archive( + self, + sandbox_id: str, + meta_store, + operator=None, + dir_storage=None, + image_storage=None, + archive_params: dict | None = None, + ) -> None: + logger.info(f"archive sandbox {sandbox_id}") + sandbox_info = self.sandbox_info or {} + archive_params = archive_params or {} + prefix = archive_params.get("archive_prefix", "rock-archives/") + acr_ns = archive_params.get("acr_namespace", "sandbox_archive") + + sandbox_info["state"] = RockState.ARCHIVING + sandbox_info["archive_prefix"] = prefix + sandbox_info["acr_namespace"] = acr_ns + sandbox_info["state_enter_time"] = get_iso8601_timestamp() + await meta_store.archive(sandbox_id, sandbox_info) + self.sandbox_info = sandbox_info + + if operator: + spec = sandbox_info.get("spec") or {} + config = DockerDeploymentConfig(**spec) + await operator.start_archive( + config=config, + host_ip=sandbox_info.get("host_ip"), + dir_storage_config=dir_storage.client_config, + image_storage_config=image_storage.client_config, + archive_params=archive_params, + ) + + async def on_archive_done(self, sandbox_id: str, meta_store) -> None: + logger.info(f"archive done sandbox {sandbox_id}") + sandbox_info = self.sandbox_info or {} + sandbox_info["state"] = RockState.ARCHIVED + sandbox_info["archive_time"] = get_iso8601_timestamp() + sandbox_info.pop("state_enter_time", None) + await meta_store.archive(sandbox_id, sandbox_info) + self.sandbox_info = sandbox_info + + async def on_archive_failed(self, sandbox_id: str, meta_store, reason: str = "") -> None: + logger.info(f"archive failed sandbox {sandbox_id}: {reason}") + sandbox_info = self.sandbox_info or {} + 
sandbox_info["state"] = RockState.STOPPED + sandbox_info.pop("archive_time", None) + sandbox_info.pop("state_enter_time", None) + await meta_store.archive(sandbox_id, sandbox_info) + self.sandbox_info = sandbox_info + + async def on_restore( + self, + sandbox_id: str, + meta_store, + timeout_info: dict | None = None, + operator=None, + dir_storage=None, + image_storage=None, + ) -> None: + """ARCHIVED → PENDING: fire-and-forget actor does pull+download+docker start. + + Writes to Redis so that operator.get_status can probe alive status + (PENDING alive detection drives the PENDING → RUNNING transition). + """ + logger.info(f"restore sandbox {sandbox_id}") + sandbox_info = self.sandbox_info or {} + sandbox_info["state"] = RockState.PENDING + sandbox_info["state_enter_time"] = get_iso8601_timestamp() + sandbox_info.pop("stop_time", None) + await meta_store.update(sandbox_id, sandbox_info) + if timeout_info: + await meta_store.update_timeout(sandbox_id, timeout_info) + self.sandbox_info = sandbox_info + + if operator: + spec = sandbox_info.get("spec") or {} + archive_params = { + "archive_prefix": sandbox_info.get("archive_prefix", "rock-archives/"), + "acr_namespace": sandbox_info.get("acr_namespace", "sandbox_archive"), + } + config = DockerDeploymentConfig(**spec) + await operator.start_restore( + config=config, + host_ip=sandbox_info.get("host_ip"), + dir_storage_config=dir_storage.client_config, + image_storage_config=image_storage.client_config, + archive_params=archive_params, + ) + + async def on_restore_failed(self, sandbox_id: str, meta_store, reason: str = "") -> None: + """PENDING → ARCHIVED: timeout or unrecoverable error during restore.""" + logger.info(f"restore failed sandbox {sandbox_id}: {reason}") + sandbox_info = self.sandbox_info or {} + sandbox_info["state"] = RockState.ARCHIVED + sandbox_info.pop("state_enter_time", None) + await meta_store.archive(sandbox_id, sandbox_info) + self.sandbox_info = sandbox_info + @classmethod async def 
from_state_value(cls, state_value: str | None, sandbox_info: SandboxInfo) -> "SandboxStateMachine": """Create a state machine restored to *state_value* (from Redis/DB).""" @@ -180,6 +303,8 @@ async def from_state_value(cls, state_value: str | None, sandbox_info: SandboxIn RockState.PENDING: "pending", RockState.RUNNING: "running", RockState.STOPPED: "stopped", + RockState.ARCHIVING: "archiving", + RockState.ARCHIVED: "archived", RockState.DELETED: "deleted", } sm = ( diff --git a/rock/sdk/sandbox/client.py b/rock/sdk/sandbox/client.py index 23290da9cd..5a9dd54fd4 100644 --- a/rock/sdk/sandbox/client.py +++ b/rock/sdk/sandbox/client.py @@ -318,6 +318,27 @@ async def restart(self): f"Failed to restart sandbox within {self.config.startup_timeout}s, sandbox: {str(self)}" ) + async def archive(self): + """Archive a stopped sandbox (snapshot container + upload logs to S3). + + The sandbox must be in STOPPED state. After this call the server transitions + it to ARCHIVING, and a background scanner will move it to ARCHIVED once + both the image and log uploads are confirmed. 
+ """ + if not self.sandbox_id: + raise Exception("sandbox_id is not set, cannot archive") + url = f"{self._url}/archive" + headers = self._build_headers() + data = {"sandbox_id": self.sandbox_id} + response = await HttpUtils.post(url, headers, data) + logging.debug(f"Archive sandbox response: {response}") + if "Success" != response.get("status"): + result = response.get("result", None) + if result is not None: + rock_response = SandboxResponse(**result) + raise_for_code(rock_response.code, f"Failed to archive sandbox: {response}") + raise Exception(f"Failed to archive sandbox: {response}") + async def commit(self, image_tag: str, username: str, password: str): if not self.sandbox_id: return diff --git a/scripts/gen_ddl.py b/scripts/gen_ddl.py index 584bb25b21..64aa09f632 100644 --- a/scripts/gen_ddl.py +++ b/scripts/gen_ddl.py @@ -3,10 +3,14 @@ Usage: uv run python scripts/gen_ddl.py # default: postgresql uv run python scripts/gen_ddl.py --dialect sqlite - uv run python scripts/gen_ddl.py --dialect postgresql --out ddl.sql + uv run python scripts/gen_ddl.py --table sandbox_record --out sql/sandbox_record.sql + uv run python scripts/gen_ddl.py --alter-from sql/sandbox_record.sql # diff against old DDL file + uv run python scripts/gen_ddl.py --alter-from HEAD~1 # diff against git commit/tag + uv run python scripts/gen_ddl.py --alter-from v1.3.0 # diff against git tag """ import argparse +import re import sys from sqlalchemy.schema import CreateIndex, CreateTable @@ -25,12 +29,20 @@ def get_dialect(name: str): sys.exit(1) -def gen_ddl(dialect) -> str: +def gen_ddl(dialect, table_filter: str | None = None) -> str: # Import here so all models are registered onto Base.metadata from rock.admin.core.schema import Base # noqa: F401 (side-effect: registers SandboxRecord) + tables = Base.metadata.sorted_tables + if table_filter: + names = {n.strip() for n in table_filter.split(",")} + tables = [t for t in tables if t.name in names] + if not tables: + print(f"No tables 
matched: {table_filter}", file=sys.stderr) + sys.exit(1) + lines: list[str] = [] - for table in Base.metadata.sorted_tables: + for table in tables: lines.append(str(CreateTable(table).compile(dialect=dialect)).strip() + ";") # table.indexes has set-like semantics; sort for deterministic output. for index in sorted(table.indexes, key=lambda idx: idx.name or ""): @@ -39,21 +51,120 @@ def gen_ddl(dialect) -> str: return "\n\n".join(lines) +def _parse_columns_from_sql(sql: str) -> dict[str, str]: + """Parse column definitions from a CREATE TABLE statement. Returns {column_name: full_definition}.""" + match = re.search(r"CREATE TABLE (\w+)\s*\((.*)\)", sql, re.DOTALL) + if not match: + return {} + body = match.group(2) + columns: dict[str, str] = {} + for line in body.split("\n"): + line = line.strip().rstrip(",") + if not line or line.startswith("PRIMARY KEY") or line.startswith("CONSTRAINT") or line.startswith(")"): + continue + parts = line.split() + if parts: + columns[parts[0].lower()] = line + return columns + + +def _parse_indexes_from_sql(sql: str) -> dict[str, str]: + """Parse CREATE INDEX statements. Returns {index_name: full_statement}.""" + indexes: dict[str, str] = {} + for match in re.finditer(r"(CREATE\s+(?:UNIQUE\s+)?INDEX\s+(\w+)\s+[^;]+);", sql, re.IGNORECASE): + indexes[match.group(2).lower()] = match.group(1).strip() + ";" + return indexes + + +def _read_old_sql(source: str) -> str: + """Read old DDL from a file path or git ref (commit/tag). + + If source looks like a file path (contains '/' or '.sql'), read it directly. + Otherwise treat it as a git ref and read sql/sandbox_record.sql from that ref. 
+ """ + if "/" in source or source.endswith(".sql"): + with open(source) as f: + return f.read() + import subprocess + + # Try each known SQL path under the git ref + for sql_path in ["sql/sandbox_record.sql"]: + try: + return subprocess.check_output( + ["git", "show", f"{source}:{sql_path}"], stderr=subprocess.DEVNULL, text=True + ) + except subprocess.CalledProcessError: + continue + print(f"Cannot find DDL file in git ref '{source}'", file=sys.stderr) + sys.exit(1) + + +def gen_alter(dialect, source: str, table_filter: str | None = None) -> str: + """Compare old DDL (file or git ref) against current ORM and generate ALTER TABLE statements.""" + from rock.admin.core.schema import Base + + old_sql = _read_old_sql(source) + + old_columns = _parse_columns_from_sql(old_sql) + old_indexes = _parse_indexes_from_sql(old_sql) + + tables = Base.metadata.sorted_tables + if table_filter: + names = {n.strip() for n in table_filter.split(",")} + tables = [t for t in tables if t.name in names] + else: + table_match = re.search(r"CREATE TABLE (\w+)", old_sql) + if table_match: + target = table_match.group(1).lower() + tables = [t for t in tables if t.name == target] + + if not tables: + print("No matching table found", file=sys.stderr) + sys.exit(1) + + lines: list[str] = [] + for table in tables: + new_ddl = str(CreateTable(table).compile(dialect=dialect)).strip() + new_columns = _parse_columns_from_sql(new_ddl) + + for col_name, col_def in new_columns.items(): + if col_name not in old_columns: + lines.append(f"ALTER TABLE {table.name} ADD COLUMN {col_def};") + + for idx in sorted(table.indexes, key=lambda i: i.name or ""): + idx_name = (idx.name or "").lower() + if idx_name and idx_name not in old_indexes: + lines.append(str(CreateIndex(idx).compile(dialect=dialect)).strip() + ";") + + return "\n\n".join(lines) if lines else "-- No changes detected" + + def main() -> None: parser = argparse.ArgumentParser(description="Generate DDL from ORM schema") 
parser.add_argument("--dialect", default="postgresql", choices=["postgresql", "sqlite"]) + parser.add_argument("--table", default=None, help="Comma-separated table names to generate (default: all)") + parser.add_argument( + "--alter-from", + default=None, + dest="alter_from", + help="Old DDL file path or git ref (commit/tag) to diff against", + ) parser.add_argument("--out", default=None, help="Output file path (default: stdout)") args = parser.parse_args() dialect = get_dialect(args.dialect) - ddl = gen_ddl(dialect) + + if args.alter_from: + result = gen_alter(dialect, args.alter_from, table_filter=args.table) + else: + result = gen_ddl(dialect, table_filter=args.table) if args.out: with open(args.out, "w") as f: - f.write(ddl + "\n") + f.write(result + "\n") print(f"Written to {args.out}") else: - print(ddl) + print(result) if __name__ == "__main__": diff --git a/scripts/gen_statemachine_diagram.py b/scripts/gen_statemachine_diagram.py index c618c9be52..b018bdbf3f 100644 --- a/scripts/gen_statemachine_diagram.py +++ b/scripts/gen_statemachine_diagram.py @@ -4,9 +4,25 @@ uv run python scripts/gen_statemachine_diagram.py """ +import re + +import pydot from statemachine.contrib.diagram import DotGraphMachine from rock.sandbox.sandbox_statemachine import SandboxStateMachine -DotGraphMachine(SandboxStateMachine)().write_png("sandbox_statemachine.png") +INTERMEDIATE_STATES = {"pending", "archiving"} + +graph = DotGraphMachine(SandboxStateMachine)() +dot_src = graph.to_string() + +for state in INTERMEDIATE_STATES: + dot_src = re.sub( + rf"({state} \[.*?)fillcolor=white", + r'\1fillcolor="#FFF3CD", color="#856404"', + dot_src, + ) + +(styled_graph,) = pydot.graph_from_dot_data(dot_src) +styled_graph.write_png("sandbox_statemachine.png") print("Written to sandbox_statemachine.png") diff --git a/sql/sandbox_record.sql b/sql/sandbox_record.sql index b02bb565a0..103fb5eefe 100644 --- a/sql/sandbox_record.sql +++ b/sql/sandbox_record.sql @@ -16,6 +16,9 @@ CREATE TABLE 
sandbox_record ( cpus FLOAT, memory VARCHAR(64), create_user_gray_flag BOOLEAN, + archive_time VARCHAR(64), + state_enter_time VARCHAR(64), + delete_time VARCHAR(64), phases JSONB, port_mapping JSONB, spec JSONB, @@ -23,20 +26,20 @@ CREATE TABLE sandbox_record ( PRIMARY KEY (sandbox_id) ); -CREATE INDEX ix_sandbox_record_image ON sandbox_record (image); +CREATE INDEX ix_sandbox_record_cluster_name ON sandbox_record (cluster_name); -CREATE INDEX ix_sandbox_record_host_ip ON sandbox_record (host_ip); +CREATE INDEX ix_sandbox_record_create_user_gray_flag ON sandbox_record (create_user_gray_flag); -CREATE INDEX ix_sandbox_record_host_name ON sandbox_record (host_name); +CREATE INDEX ix_sandbox_record_experiment_id ON sandbox_record (experiment_id); -CREATE INDEX ix_sandbox_record_state ON sandbox_record (state); +CREATE INDEX ix_sandbox_record_host_ip ON sandbox_record (host_ip); -CREATE INDEX ix_sandbox_record_create_user_gray_flag ON sandbox_record (create_user_gray_flag); +CREATE INDEX ix_sandbox_record_host_name ON sandbox_record (host_name); -CREATE INDEX ix_sandbox_record_user_id ON sandbox_record (user_id); +CREATE INDEX ix_sandbox_record_image ON sandbox_record (image); CREATE INDEX ix_sandbox_record_namespace ON sandbox_record (namespace); -CREATE INDEX ix_sandbox_record_experiment_id ON sandbox_record (experiment_id); +CREATE INDEX ix_sandbox_record_state ON sandbox_record (state); -CREATE INDEX ix_sandbox_record_cluster_name ON sandbox_record (cluster_name); +CREATE INDEX ix_sandbox_record_user_id ON sandbox_record (user_id); diff --git a/tests/integration/archive/__init__.py b/tests/integration/archive/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/archive/test_archive_e2e.py b/tests/integration/archive/test_archive_e2e.py new file mode 100644 index 0000000000..c030288359 --- /dev/null +++ b/tests/integration/archive/test_archive_e2e.py @@ -0,0 +1,208 @@ +"""E2E integration tests for archive + restore + delete 
full cycle. + +Uses real MinIO + Docker Registry + a real docker container. +Does NOT require Ray — instantiates the underlying SandboxActor class directly. +""" + +import hashlib +import os +import subprocess +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from rock.actions.sandbox.response import State +from rock.sandbox.archive._constants import dir_archive_key, image_ref +from rock.sandbox.archive.registry_v2 import DockerRegistryV2ImageStorage +from rock.sandbox.archive.s3_storage import S3DirStorage + +pytestmark = [pytest.mark.integration] + + +@pytest.fixture +def sandbox_container(): + """Create a real docker container for archive testing.""" + container_name = f"test-archive-e2e-{os.getpid()}" + subprocess.run( + ["docker", "run", "-d", "--name", container_name, "busybox:latest", "sleep", "3600"], + check=True, + capture_output=True, + ) + yield container_name + subprocess.run(["docker", "rm", "-f", container_name], capture_output=True) + + +@pytest.fixture +def log_dir(sandbox_container): + """Create a temp log dir with known content and return (path, file_hashes).""" + base = tempfile.mkdtemp(prefix="rock-archive-e2e-logs-") + log_path = Path(base) / sandbox_container + log_path.mkdir(parents=True) + + hashes = {} + for name, content in [("app.log", b"hello world\n" * 100), ("err.log", b"error line\n")]: + p = log_path / name + p.write_bytes(content) + hashes[name] = hashlib.sha256(content).hexdigest() + + yield str(base), hashes + + import shutil + + shutil.rmtree(base, ignore_errors=True) + + +@pytest.fixture +def dir_storage(local_minio): + endpoint, access_key, secret_key, bucket = local_minio + return S3DirStorage( + endpoint=endpoint, bucket=bucket, access_key_id=access_key, access_key_secret=secret_key, region="us-east-1" + ) + + +@pytest.fixture +def image_storage(local_snapshot_registry): + registry_url, username, password = local_snapshot_registry + return 
DockerRegistryV2ImageStorage(registry_url=registry_url, username=username, password=password) + + +@pytest.fixture +def actor(sandbox_container): + """Create SandboxActor instance (no Ray) using the underlying class.""" + from rock.sandbox.sandbox_actor import SandboxActor + + ActorClass = SandboxActor.__ray_actor_class__ + + config = MagicMock() + config.container_name = sandbox_container + deployment = MagicMock() + deployment.restart = AsyncMock() + deployment.restart_from_image = AsyncMock() + + actor = ActorClass.__new__(ActorClass) + actor._config = config + actor._deployment = deployment + actor._clean_container_background = MagicMock() + actor._setup_monitor = AsyncMock() + return actor + + +class TestArchiveRestoreDeleteE2E: + """Full lifecycle: archive → verify exists → restore → verify restored → delete cleanup.""" + + @pytest.fixture(autouse=True) + def _patch_actor_storage(self, monkeypatch, log_dir): + """Patch actor to use local S3 instead of OSS, and set ROCK_LOGGING_PATH.""" + import rock.sandbox.archive.oss_storage as oss_mod + + monkeypatch.setattr(oss_mod, "OssDirStorage", S3DirStorage) + + import rock.env_vars as _env_vars + + log_root, _ = log_dir + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", log_root) + + async def test_full_cycle(self, actor, dir_storage, image_storage, log_dir, sandbox_container, local_minio): + log_root, original_hashes = log_dir + + dir_cfg = dir_storage.client_config + img_cfg = image_storage.client_config + + # === ARCHIVE === + await actor.archive(dir_cfg, img_cfg) + + key = dir_archive_key(sandbox_container, "rock-archives/") + ref = image_ref(sandbox_container, image_storage.registry_url, "sandbox_archive") + + assert await dir_storage.exists(key), "Archive dir object should exist in MinIO" + assert await image_storage.exists(ref), "Archive image should exist in registry" + + # === SIMULATE "LOCAL IS GONE" === + log_path = Path(log_root) / sandbox_container + import shutil + + shutil.rmtree(str(log_path), 
ignore_errors=True) + assert not log_path.exists() + + local_tag = f"archive-staging-{sandbox_container}:latest" + subprocess.run(["docker", "rmi", local_tag], capture_output=True) + subprocess.run(["docker", "rmi", ref], capture_output=True) + + # === RESTORE === + await actor.restore_and_start(dir_cfg, img_cfg) + + assert log_path.exists(), "Log dir should be restored" + for name, expected_hash in original_hashes.items(): + restored = (log_path / name).read_bytes() + actual_hash = hashlib.sha256(restored).hexdigest() + assert actual_hash == expected_hash, f"File {name} content mismatch after restore" + + result = subprocess.run(["docker", "image", "inspect", ref], capture_output=True) + assert result.returncode == 0, "Image should be pullable locally after restore" + + # Remote still intact + assert await dir_storage.exists(key), "Dir object should still exist after restore" + assert await image_storage.exists(ref), "Image should still exist after restore" + + # === DELETE CLEANUP === + await dir_storage.delete(key) + await image_storage.delete(ref) + + assert not await dir_storage.exists(key), "Dir object should be gone after delete" + assert not await image_storage.exists(ref), "Image should be gone after delete" + + async def test_archive_rollback_on_upload_failure( + self, actor, image_storage, sandbox_container, local_minio, monkeypatch + ): + """If dir upload fails, image should be rolled back.""" + import rock.env_vars as _env_vars + + dir_cfg = { + "endpoint": "http://localhost:1", + "bucket": "nonexistent", + "access_key": "x", + "secret_key": "x", + "region": "us-east-1", + } + img_cfg = image_storage.client_config + + log_root = tempfile.mkdtemp(prefix="rock-archive-rollback-") + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", log_root) + log_path = Path(log_root) / sandbox_container + log_path.mkdir(parents=True) + (log_path / "dummy.log").write_bytes(b"data") + + try: + with pytest.raises(Exception): + await actor.archive(dir_cfg, img_cfg) + + 
ref = image_ref(sandbox_container, image_storage.registry_url, "sandbox_archive") + assert not await image_storage.exists(ref), "Image should be cleaned up on failure" + finally: + import shutil + + shutil.rmtree(log_root, ignore_errors=True) + + +class TestIdempotentRestore: + """Verify that restoring when already restored is a no-op at the manager level.""" + + async def test_restart_from_archived_skips_if_not_archived(self): + """State machine guard: if state != ARCHIVED, skip entirely.""" + from rock.sandbox.sandbox_manager import SandboxManager + + m = MagicMock(spec=SandboxManager) + m._operator = MagicMock() + m._operator.restore_archive = AsyncMock() + + sm = AsyncMock() + sm.current_state.value = State.STOPPED + sm.sandbox_info = {"sandbox_id": "sbx-x"} + m._get_current_statemachine = AsyncMock(return_value=sm) + + m.restart_from_archived = SandboxManager.restart_from_archived.__get__(m, SandboxManager) + await m.restart_from_archived("sbx-x") + + m._operator.restore_archive.assert_not_called() diff --git a/tests/integration/archive/test_registry_v2.py b/tests/integration/archive/test_registry_v2.py new file mode 100644 index 0000000000..a097d881a8 --- /dev/null +++ b/tests/integration/archive/test_registry_v2.py @@ -0,0 +1,77 @@ +import subprocess + +import pytest + +from rock.sandbox.archive.registry_v2 import DockerRegistryV2ImageStorage + + +@pytest.fixture +def image_storage(local_snapshot_registry) -> DockerRegistryV2ImageStorage: + registry_url, username, password = local_snapshot_registry + return DockerRegistryV2ImageStorage(registry_url=registry_url, username=username, password=password) + + +@pytest.fixture(scope="session") +def test_container(): + """Create a busybox container for commit tests.""" + name = "test-archive-source" + subprocess.run(["docker", "rm", "-f", name], capture_output=True) + subprocess.run( + ["docker", "run", "-d", "--name", name, "busybox:latest", "sleep", "3600"], + check=True, + capture_output=True, + ) + yield name + 
subprocess.run(["docker", "rm", "-f", name], capture_output=True) + + +@pytest.fixture +def local_image(test_container): + """Commit the test container to a local image tag.""" + tag = "archive-test-local:latest" + subprocess.run(["docker", "commit", test_container, tag], check=True, capture_output=True) + yield tag + subprocess.run(["docker", "rmi", tag], capture_output=True) + + +class TestDockerRegistryV2ImageStorage: + async def test_push_then_exists(self, image_storage, local_image): + ref = f"{image_storage.registry_url}/test-ns/myimg:v1" + await image_storage.push_from_local(local_image, ref) + assert await image_storage.exists(ref) is True + await image_storage.delete(ref) + + async def test_push_pull_roundtrip(self, image_storage, local_image): + ref = f"{image_storage.registry_url}/test-ns/roundtrip:v1" + await image_storage.push_from_local(local_image, ref) + + subprocess.run(["docker", "rmi", ref], capture_output=True) + + await image_storage.pull_to_local(ref) + result = subprocess.run(["docker", "image", "inspect", ref], capture_output=True) + assert result.returncode == 0 + + subprocess.run(["docker", "rmi", ref], capture_output=True) + await image_storage.delete(ref) + + async def test_pull_nonexistent_raises(self, image_storage): + ref = f"{image_storage.registry_url}/test-ns/nosuch:v999" + with pytest.raises(RuntimeError): + await image_storage.pull_to_local(ref) + + async def test_delete_then_not_exists(self, image_storage, local_image): + ref = f"{image_storage.registry_url}/test-ns/todelete:v1" + await image_storage.push_from_local(local_image, ref) + assert await image_storage.delete(ref) is True + assert await image_storage.exists(ref) is False + + async def test_delete_nonexistent(self, image_storage): + ref = f"{image_storage.registry_url}/test-ns/nope:v1" + assert await image_storage.delete(ref) is False + + async def test_push_idempotent(self, image_storage, local_image): + ref = f"{image_storage.registry_url}/test-ns/idempotent:v1" + 
await image_storage.push_from_local(local_image, ref) + await image_storage.push_from_local(local_image, ref) + assert await image_storage.exists(ref) is True + await image_storage.delete(ref) diff --git a/tests/integration/archive/test_s3_storage.py b/tests/integration/archive/test_s3_storage.py new file mode 100644 index 0000000000..82462f5911 --- /dev/null +++ b/tests/integration/archive/test_s3_storage.py @@ -0,0 +1,87 @@ +import os +import tempfile + +import pytest + +from rock.sandbox.archive.s3_storage import S3DirStorage + + +@pytest.fixture +def s3_storage(local_minio): + endpoint, access_key, secret_key, bucket = local_minio + return S3DirStorage( + endpoint=endpoint, + bucket=bucket, + access_key_id=access_key, + access_key_secret=secret_key, + region="us-east-1", + ) + + +@pytest.fixture +def sample_dir(): + with tempfile.TemporaryDirectory() as td: + d = os.path.join(td, "mydata") + os.makedirs(d) + with open(os.path.join(d, "file1.txt"), "w") as f: + f.write("hello world") + os.makedirs(os.path.join(d, "subdir")) + with open(os.path.join(d, "subdir", "file2.bin"), "wb") as f: + f.write(b"\x00\x01\x02" * 100) + yield d + + +class TestS3DirStorage: + async def test_upload_exists_delete_cycle(self, s3_storage, sample_dir): + key = "test/cycle.tar.gz" + await s3_storage.upload_dir(sample_dir, key) + assert await s3_storage.exists(key) is True + + assert await s3_storage.delete(key) is True + assert await s3_storage.exists(key) is False + + async def test_delete_nonexistent(self, s3_storage): + assert await s3_storage.delete("no/such/key") is False + + async def test_upload_idempotent(self, s3_storage, sample_dir): + key = "test/idempotent.tar.gz" + await s3_storage.upload_dir(sample_dir, key) + await s3_storage.upload_dir(sample_dir, key) + assert await s3_storage.exists(key) is True + await s3_storage.delete(key) + + async def test_upload_download_roundtrip(self, s3_storage, sample_dir): + key = "test/roundtrip.tar.gz" + await 
s3_storage.upload_dir(sample_dir, key) + + with tempfile.TemporaryDirectory() as td: + restore_dir = os.path.join(td, "mydata") + await s3_storage.download_to_dir(key, restore_dir) + + assert os.path.isdir(restore_dir) + with open(os.path.join(restore_dir, "file1.txt")) as f: + assert f.read() == "hello world" + with open(os.path.join(restore_dir, "subdir", "file2.bin"), "rb") as f: + assert f.read() == b"\x00\x01\x02" * 100 + + await s3_storage.delete(key) + + async def test_upload_nonexistent_dir_raises(self, s3_storage): + with pytest.raises(FileNotFoundError): + await s3_storage.upload_dir("/tmp/does-not-exist-xyz", "test/nope.tar.gz") + + async def test_download_target_exists_raises(self, s3_storage, sample_dir): + key = "test/exists-check.tar.gz" + await s3_storage.upload_dir(sample_dir, key) + + with pytest.raises(FileExistsError): + await s3_storage.download_to_dir(key, sample_dir) + + await s3_storage.delete(key) + + async def test_download_missing_key_raises(self, s3_storage): + with tempfile.TemporaryDirectory() as td: + target = os.path.join(td, "output") + with pytest.raises(FileNotFoundError): + await s3_storage.download_to_dir("no/key", target) + assert not os.path.exists(target) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index b39a9a385d..a101bdce79 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -297,3 +297,126 @@ async def local_registry(): subprocess.run(["docker", "rm", "-f", container_name], capture_output=True) htpasswd_file.unlink(missing_ok=True) auth_dir.rmdir() + + +@pytest.fixture(scope="session") +def local_minio(): + """Start a local MinIO server for archive storage tests.""" + container_name = "test-minio-archive" + subprocess.run(["docker", "rm", "-f", container_name], capture_output=True) + port = run_until_complete(find_free_port()) + console_port = run_until_complete(find_free_port()) + subprocess.run( + [ + "docker", + "run", + "-d", + "--name", + container_name, + 
"-p", + f"{port}:9000", + "-p", + f"{console_port}:9001", + "-e", + "MINIO_ROOT_USER=rockadmin", + "-e", + "MINIO_ROOT_PASSWORD=rockadmin123", + "minio/minio:RELEASE.2024-12-18T13-15-44Z", + "server", + "/data", + "--console-address", + ":9001", + ], + check=True, + ) + + endpoint = f"http://localhost:{port}" + for _ in range(30): + try: + urllib.request.urlopen(f"{endpoint}/minio/health/live", timeout=1) + break + except (urllib.error.URLError, ConnectionError, OSError): + time.sleep(0.5) + else: + raise RuntimeError(f"MinIO at {endpoint} did not become ready") + + import boto3 + + s3 = boto3.client( + "s3", + endpoint_url=endpoint, + aws_access_key_id="rockadmin", + aws_secret_access_key="rockadmin123", + region_name="us-east-1", + ) + s3.create_bucket(Bucket="rock-archive-test") + + yield endpoint, "rockadmin", "rockadmin123", "rock-archive-test" + + subprocess.run(["docker", "rm", "-f", container_name], capture_output=True) + + +@pytest.fixture(scope="session") +def local_snapshot_registry(): + """Start a local Docker registry with basic auth and delete support.""" + container_name = "test-snapshot-registry" + username, password = "testuser", "testpass" + subprocess.run(["docker", "rm", "-f", container_name], capture_output=True) + port = run_until_complete(find_free_port()) + + auth_dir = tempfile.mkdtemp() + htpasswd_file = os.path.join(auth_dir, "htpasswd") + result = subprocess.run( + ["docker", "run", "--rm", "httpd:2", "htpasswd", "-Bbn", username, password], + capture_output=True, + text=True, + ) + with open(htpasswd_file, "w") as f: + f.write(result.stdout) + + subprocess.run( + [ + "docker", + "run", + "-d", + "--name", + container_name, + "-p", + f"{port}:5000", + "-v", + f"{htpasswd_file}:/auth/htpasswd", + "-e", + "REGISTRY_AUTH=htpasswd", + "-e", + "REGISTRY_AUTH_HTPASSWD_REALM=Registry Realm", + "-e", + "REGISTRY_AUTH_HTPASSWD_PATH=/auth/htpasswd", + "-e", + "REGISTRY_STORAGE_DELETE_ENABLED=true", + "registry:2", + ], + check=True, + ) + + 
registry_url = f"localhost:{port}" + for _ in range(30): + try: + req = urllib.request.Request(f"http://{registry_url}/v2/") + import base64 + + credentials = base64.b64encode(f"{username}:{password}".encode()).decode() + req.add_header("Authorization", f"Basic {credentials}") + r = urllib.request.urlopen(req, timeout=1) + if r.status == 200: + break + except (urllib.error.URLError, ConnectionError, OSError): + time.sleep(0.5) + else: + raise RuntimeError(f"Registry at {registry_url} did not become ready") + + yield registry_url, username, password + + subprocess.run(["docker", "rm", "-f", container_name], capture_output=True) + import shutil + + shutil.rmtree(auth_dir, ignore_errors=True) diff --git a/tests/unit/sandbox/__init__.py b/tests/unit/sandbox/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sandbox/archive/__init__.py b/tests/unit/sandbox/archive/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sandbox/archive/test_archive_not_configured.py b/tests/unit/sandbox/archive/test_archive_not_configured.py new file mode 100644 index 0000000000..3d9f1913c6 --- /dev/null +++ b/tests/unit/sandbox/archive/test_archive_not_configured.py @@ -0,0 +1,70 @@ +"""Tests: archive methods must return errors / skip gracefully when archive is not configured.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from rock.config import SandboxLifecycleConfig +from rock.sdk.common.exceptions import BadRequestRockError + + +@pytest.fixture +def manager_no_archive(): + """SandboxManager-like mock WITHOUT _dir_storage / _image_storage.""" + from rock.sandbox.sandbox_manager import SandboxManager + + m = MagicMock(spec=SandboxManager) + m.rock_config.lifecycle = SandboxLifecycleConfig() + m._meta_store = AsyncMock() + m._operator = MagicMock() + m._operator.supports_archive.return_value = True + + m._dir_storage = None + m._image_storage = None + + m.archive_sandbox = 
SandboxManager.archive_sandbox.__get__(m, SandboxManager) + m.restart_from_archived = SandboxManager.restart_from_archived.__get__(m, SandboxManager) + m._reconcile_archiving = SandboxManager._reconcile_archiving.__get__(m, SandboxManager) + m._auto_archive_stopped = SandboxManager._auto_archive_stopped.__get__(m, SandboxManager) + return m + + +class TestArchiveNotConfigured: + async def test_archive_sandbox_raises_error(self, manager_no_archive): + with pytest.raises(BadRequestRockError, match="archive not configured"): + await manager_no_archive.archive_sandbox("sbx-1") + + async def test_restart_from_archived_raises_error(self, manager_no_archive): + with pytest.raises(BadRequestRockError, match="archive not configured"): + await manager_no_archive.restart_from_archived("sbx-1") + + async def test_reconcile_archiving_skips(self, manager_no_archive): + await manager_no_archive._reconcile_archiving() + manager_no_archive._meta_store.list_by.assert_not_called() + + async def test_auto_archive_stopped_skips(self, manager_no_archive): + manager_no_archive.rock_config.lifecycle.auto_archive_after_sec = 3600 + await manager_no_archive._auto_archive_stopped() + manager_no_archive._meta_store.list_by.assert_not_called() + + +class TestArchiveOperatorNotSupported: + async def test_archive_sandbox_raises_error(self, manager_no_archive): + manager_no_archive._dir_storage = AsyncMock() + manager_no_archive._image_storage = AsyncMock() + manager_no_archive._operator.supports_archive.return_value = False + with pytest.raises(BadRequestRockError, match="archive not supported"): + await manager_no_archive.archive_sandbox("sbx-1") + + async def test_reconcile_archiving_skips(self, manager_no_archive): + manager_no_archive._operator.supports_archive.return_value = False + await manager_no_archive._reconcile_archiving() + manager_no_archive._meta_store.list_by.assert_not_called() + + async def test_auto_archive_stopped_skips(self, manager_no_archive): + 
manager_no_archive._dir_storage = AsyncMock() + manager_no_archive._image_storage = AsyncMock() + manager_no_archive._operator.supports_archive.return_value = False + manager_no_archive.rock_config.lifecycle.auto_archive_after_sec = 3600 + await manager_no_archive._auto_archive_stopped() + manager_no_archive._meta_store.list_by.assert_not_called() diff --git a/tests/unit/sandbox/archive/test_archive_sandbox_orchestration.py b/tests/unit/sandbox/archive/test_archive_sandbox_orchestration.py new file mode 100644 index 0000000000..5dfbaffed1 --- /dev/null +++ b/tests/unit/sandbox/archive/test_archive_sandbox_orchestration.py @@ -0,0 +1,139 @@ +"""Unit tests for SandboxManager.archive_sandbox and restart_from_archived.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from rock.actions.sandbox.response import State +from rock.sdk.common.exceptions import BadRequestRockError + + +@pytest.fixture +def manager(): + from rock.sandbox.sandbox_manager import SandboxManager + + m = MagicMock(spec=SandboxManager) + m._meta_store = AsyncMock() + m._meta_store._db = AsyncMock() + m._meta_store._db.get = AsyncMock( + return_value={"sandbox_id": "sbx-1", "spec": {"container_name": "sbx-1", "image": "img:latest"}} + ) + m._operator = MagicMock() + m._operator.supports_archive.return_value = True + m._operator.start_archive = AsyncMock() + m._operator.start_restore = AsyncMock() + m._dir_storage = AsyncMock() + m._dir_storage.client_config = {"endpoint": "e", "bucket": "b", "access_key": "a", "secret_key": "s", "region": "r"} + m._dir_storage.delete = AsyncMock(return_value=True) + m._image_storage = AsyncMock() + m._image_storage.registry_url = "localhost:5000" + m._image_storage.client_config = {"registry_url": "localhost:5000"} + m._image_storage.delete = AsyncMock(return_value=True) + + from rock.config import ArchiveConfig + + m.rock_config = MagicMock() + m.rock_config.lifecycle.archive = ArchiveConfig() + + m.archive_sandbox = 
SandboxManager.archive_sandbox.__get__(m, SandboxManager) + m.restart_from_archived = SandboxManager.restart_from_archived.__get__(m, SandboxManager) + return m + + +@pytest.fixture +def sm_stopped(): + sm = AsyncMock() + sm.current_state.value = State.STOPPED + sm.sandbox_info = { + "sandbox_id": "sbx-1", + "state": State.STOPPED, + "host_ip": "10.0.0.1", + "spec": {"container_name": "sbx-1", "image": "img:latest"}, + } + return sm + + +@pytest.fixture +def sm_archived(): + sm = AsyncMock() + sm.current_state.value = State.ARCHIVED + sm.sandbox_info = { + "sandbox_id": "sbx-1", + "state": State.ARCHIVED, + "archive_time": "2026-01-01T000000Z", + "host_ip": "10.0.0.1", + "spec": {"container_name": "sbx-1", "image": "img:latest"}, + } + return sm + + +class TestArchiveSandbox: + async def test_happy_path(self, manager, sm_stopped): + manager._get_current_statemachine = AsyncMock(return_value=sm_stopped) + await manager.archive_sandbox("sbx-1") + + sm_stopped.send.assert_called_once() + assert sm_stopped.send.call_args[0][0] == "archive" + kwargs = sm_stopped.send.call_args[1] + assert kwargs["operator"] is manager._operator + assert kwargs["dir_storage"] is manager._dir_storage + assert kwargs["image_storage"] is manager._image_storage + assert "archive_params" in kwargs + + async def test_unsupported_operator_raises(self, manager, sm_stopped): + manager._operator.supports_archive.return_value = False + with pytest.raises(BadRequestRockError): + await manager.archive_sandbox("sbx-1") + + async def test_not_found_raises(self, manager): + manager._get_current_statemachine = AsyncMock(return_value=None) + with pytest.raises(BadRequestRockError): + await manager.archive_sandbox("sbx-1") + + async def test_passes_storage(self, manager, sm_stopped): + """Verify storage is passed through to on_archive for operator use.""" + manager._get_current_statemachine = AsyncMock(return_value=sm_stopped) + + await manager.archive_sandbox("sbx-1") + + 
sm_stopped.send.assert_called_once() + kwargs = sm_stopped.send.call_args[1] + assert kwargs["dir_storage"] is manager._dir_storage + assert kwargs["image_storage"] is manager._image_storage + + +class TestRestartFromArchived: + async def test_happy_path(self, manager, sm_archived): + manager._get_current_statemachine = AsyncMock(return_value=sm_archived) + await manager.restart_from_archived("sbx-1") + + sm_archived.send.assert_called_once() + assert sm_archived.send.call_args[0][0] == "restore" + kwargs = sm_archived.send.call_args[1] + assert kwargs["operator"] is manager._operator + assert kwargs["dir_storage"] is manager._dir_storage + assert kwargs["image_storage"] is manager._image_storage + + async def test_not_archived_skips(self, manager, sm_stopped): + manager._get_current_statemachine = AsyncMock(return_value=sm_stopped) + await manager.restart_from_archived("sbx-1") + + manager._operator.start_restore.assert_not_called() + sm_stopped.send.assert_not_called() + + async def test_not_found_raises(self, manager): + manager._get_current_statemachine = AsyncMock(return_value=None) + with pytest.raises(BadRequestRockError): + await manager.restart_from_archived("sbx-1") + + async def test_no_spec_raises(self, manager, sm_archived): + sm_archived.sandbox_info = { + "sandbox_id": "sbx-1", + "state": State.ARCHIVED, + "archive_time": "2026-01-01T000000Z", + "host_ip": "10.0.0.1", + "spec": {}, + } + manager._get_current_statemachine = AsyncMock(return_value=sm_archived) + with pytest.raises(BadRequestRockError): + await manager.restart_from_archived("sbx-1") diff --git a/tests/unit/sandbox/archive/test_auto_delete_stopped.py b/tests/unit/sandbox/archive/test_auto_delete_stopped.py new file mode 100644 index 0000000000..88ebc5f59f --- /dev/null +++ b/tests/unit/sandbox/archive/test_auto_delete_stopped.py @@ -0,0 +1,70 @@ +"""Unit tests for SandboxManager._auto_delete_stopped.""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import 
AsyncMock, MagicMock + +import pytest + +from rock.common.constants import DeleteReason +from rock.config import SandboxLifecycleConfig + + +@pytest.fixture +def manager(): + from rock.sandbox.sandbox_manager import SandboxManager + + m = MagicMock(spec=SandboxManager) + m._meta_store = AsyncMock() + m.rock_config = MagicMock() + m.rock_config.lifecycle = SandboxLifecycleConfig(auto_delete_after_sec=3600) + m.delete = AsyncMock() + m._auto_delete_stopped = SandboxManager._auto_delete_stopped.__get__(m, SandboxManager) + return m + + +class TestAutoDeleteStopped: + async def test_disabled_when_zero(self, manager): + manager.rock_config.lifecycle.auto_delete_after_sec = 0 + result = await manager._auto_delete_stopped() + assert result == set() + manager._meta_store.list_by.assert_not_called() + + async def test_skips_sandbox_within_threshold(self, manager): + now = datetime.now(timezone.utc) + recent_stop = (now - timedelta(seconds=60)).isoformat() + manager._meta_store.list_by = AsyncMock(return_value=[{"sandbox_id": "sbx-1", "stop_time": recent_stop}]) + result = await manager._auto_delete_stopped() + assert result == set() + manager.delete.assert_not_called() + + async def test_deletes_sandbox_past_threshold(self, manager): + now = datetime.now(timezone.utc) + old_stop = (now - timedelta(seconds=7200)).isoformat() + manager._meta_store.list_by = AsyncMock(return_value=[{"sandbox_id": "sbx-1", "stop_time": old_stop}]) + result = await manager._auto_delete_stopped() + assert "sbx-1" in result + manager.delete.assert_awaited_once_with("sbx-1", reason=DeleteReason.EXPIRED) + + async def test_delete_failure_does_not_propagate(self, manager): + now = datetime.now(timezone.utc) + old_stop = (now - timedelta(seconds=7200)).isoformat() + manager._meta_store.list_by = AsyncMock(return_value=[{"sandbox_id": "sbx-1", "stop_time": old_stop}]) + manager.delete = AsyncMock(side_effect=RuntimeError("delete failed")) + result = await manager._auto_delete_stopped() + assert result 
== set() + + async def test_empty_list_returns_empty(self, manager): + manager._meta_store.list_by = AsyncMock(return_value=[]) + result = await manager._auto_delete_stopped() + assert result == set() + + async def test_missing_stop_time_is_skipped(self, manager): + manager._meta_store.list_by = AsyncMock(return_value=[{"sandbox_id": "sbx-1", "stop_time": ""}]) + result = await manager._auto_delete_stopped() + assert result == set() + manager.delete.assert_not_called() + + async def test_list_by_failure_returns_empty(self, manager): + manager._meta_store.list_by = AsyncMock(side_effect=RuntimeError("db error")) + result = await manager._auto_delete_stopped() + assert result == set() diff --git a/tests/unit/sandbox/archive/test_check_archive_progress.py b/tests/unit/sandbox/archive/test_check_archive_progress.py new file mode 100644 index 0000000000..856db4084f --- /dev/null +++ b/tests/unit/sandbox/archive/test_check_archive_progress.py @@ -0,0 +1,103 @@ +"""Unit tests for SandboxManager._reconcile_archiving scanner.""" + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from rock.actions.sandbox.response import State +from rock.config import SandboxLifecycleConfig + + +@pytest.fixture +def manager(): + """Create a minimal SandboxManager-like object with mocked deps.""" + from rock.sandbox.sandbox_manager import SandboxManager + + m = MagicMock(spec=SandboxManager) + m.rock_config.lifecycle = SandboxLifecycleConfig() + m._meta_store = AsyncMock() + m._operator = MagicMock() + m._operator.supports_archive.return_value = True + m._dir_storage = AsyncMock() + m._dir_storage.client_config = { + "endpoint": "http://localhost:9000", + "bucket": "b", + "access_key": "a", + "secret_key": "s", + "region": "r", + } + m._image_storage = AsyncMock() + m._image_storage.registry_url = "localhost:5000" + m._image_storage.client_config = {"registry_url": "localhost:5000"} + m._get_current_statemachine = AsyncMock() + 
m._reconcile_archiving = SandboxManager._reconcile_archiving.__get__(m, SandboxManager) + return m + + +class TestCheckArchiveProgress: + async def test_empty_list_does_nothing(self, manager): + manager._meta_store.list_by = AsyncMock(return_value=[]) + await manager._reconcile_archiving() + manager._image_storage.exists.assert_not_called() + + async def test_image_exists_triggers_archive_done(self, manager): + info = { + "sandbox_id": "sbx-1", + "archive_time": "2026-01-01T000000Z", + "state_enter_time": "2026-01-01T000000Z", + "state": State.ARCHIVING, + } + manager._meta_store.list_by = AsyncMock(return_value=[info]) + manager._image_storage.exists = AsyncMock(return_value=True) + + sm_mock = AsyncMock() + sm_mock.current_state.value = State.ARCHIVING + manager._get_current_statemachine = AsyncMock(return_value=sm_mock) + + await manager._reconcile_archiving() + + sm_mock.send.assert_called_once() + call_kwargs = sm_mock.send.call_args + assert call_kwargs[0][0] == "archive_done" + + async def test_image_not_exist_within_timeout_skips(self, manager): + now = datetime.now(timezone.utc).isoformat() + info = { + "sandbox_id": "sbx-1", + "archive_time": "2026-01-01T000000Z", + "state_enter_time": now, + "state": State.ARCHIVING, + } + manager._meta_store.list_by = AsyncMock(return_value=[info]) + manager._image_storage.exists = AsyncMock(return_value=False) + + await manager._reconcile_archiving() + + manager._get_current_statemachine.assert_not_called() + + async def test_timeout_triggers_archive_failed(self, manager): + old_time = "2020-01-01T000000Z" + info = { + "sandbox_id": "sbx-1", + "archive_time": "t1", + "state_enter_time": old_time, + "state": State.ARCHIVING, + } + manager._meta_store.list_by = AsyncMock(return_value=[info]) + manager._image_storage.exists = AsyncMock(return_value=False) + + sm_mock = AsyncMock() + sm_mock.current_state.value = State.ARCHIVING + manager._get_current_statemachine = AsyncMock(return_value=sm_mock) + + await 
manager._reconcile_archiving() + + sm_mock.send.assert_called_once() + assert sm_mock.send.call_args[0][0] == "archive_failed" + + async def test_operator_not_supporting_archive_returns_early(self, manager): + manager._operator.supports_archive.return_value = False + manager._meta_store.list_by = AsyncMock() + await manager._reconcile_archiving() + manager._meta_store.list_by.assert_not_called() diff --git a/tests/unit/sandbox/archive/test_delete_clears_archive.py b/tests/unit/sandbox/archive/test_delete_clears_archive.py new file mode 100644 index 0000000000..5d8b195bbb --- /dev/null +++ b/tests/unit/sandbox/archive/test_delete_clears_archive.py @@ -0,0 +1,105 @@ +"""Unit tests for delete() passing dir_storage/image_storage to on_delete for archive cleanup.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from rock.actions.sandbox.response import State +from rock.sdk.common.exceptions import BadRequestRockError + + +@pytest.fixture +def manager(): + from rock.sandbox.sandbox_manager import SandboxManager + + m = MagicMock(spec=SandboxManager) + m._meta_store = AsyncMock() + m._operator = MagicMock() + m._operator.supports_archive.return_value = True + m._dir_storage = AsyncMock() + m._dir_storage.client_config = {"endpoint": "e", "bucket": "b", "access_key": "a", "secret_key": "s", "region": "r"} + m._dir_storage.delete = AsyncMock(return_value=True) + m._image_storage = AsyncMock() + m._image_storage.registry_url = "localhost:5000" + m._image_storage.client_config = {"registry_url": "localhost:5000"} + m._image_storage.delete = AsyncMock(return_value=True) + + m.delete = SandboxManager.delete.__get__(m, SandboxManager) + return m + + +@pytest.fixture +def sm_archived(): + sm = AsyncMock() + sm.current_state.value = State.ARCHIVED + sm.sandbox_info = { + "sandbox_id": "sbx-1", + "state": State.ARCHIVED, + "archive_time": "2026-01-01T000000Z", + "host_ip": "10.0.0.1", + "spec": {"container_name": "sbx-1", "image": "img:latest"}, + } + return 
sm + + +@pytest.fixture +def sm_stopped(): + sm = AsyncMock() + sm.current_state.value = State.STOPPED + sm.sandbox_info = { + "sandbox_id": "sbx-1", + "state": State.STOPPED, + "host_ip": "10.0.0.1", + "spec": {"container_name": "sbx-1", "image": "img:latest"}, + } + return sm + + +class TestDeletePassesStorageToCallback: + async def test_delete_archived_passes_storage(self, manager, sm_archived): + manager._get_current_statemachine = AsyncMock(return_value=sm_archived) + await manager.delete("sbx-1") + + sm_archived.send.assert_called_once() + assert sm_archived.send.call_args[0][0] == "delete" + kwargs = sm_archived.send.call_args[1] + assert kwargs["dir_storage"] is manager._dir_storage + assert kwargs["image_storage"] is manager._image_storage + + async def test_delete_stopped_passes_storage(self, manager, sm_stopped): + manager._get_current_statemachine = AsyncMock(return_value=sm_stopped) + await manager.delete("sbx-1") + + sm_stopped.send.assert_called_once() + kwargs = sm_stopped.send.call_args[1] + assert kwargs["dir_storage"] is manager._dir_storage + assert kwargs["image_storage"] is manager._image_storage + + async def test_delete_without_storage_passes_none(self, manager, sm_stopped): + manager._dir_storage = None + manager._image_storage = None + manager._get_current_statemachine = AsyncMock(return_value=sm_stopped) + await manager.delete("sbx-1") + + sm_stopped.send.assert_called_once() + kwargs = sm_stopped.send.call_args[1] + assert kwargs["dir_storage"] is None + assert kwargs["image_storage"] is None + + async def test_delete_running_still_rejected(self, manager): + sm = AsyncMock() + sm.current_state.value = State.RUNNING + sm.sandbox_info = {"sandbox_id": "sbx-1"} + manager._get_current_statemachine = AsyncMock(return_value=sm) + + with pytest.raises(BadRequestRockError, match="stopped or archived"): + await manager.delete("sbx-1") + + async def test_delete_already_deleted_noop(self, manager): + sm = AsyncMock() + sm.current_state.value = 
State.DELETED + manager._get_current_statemachine = AsyncMock(return_value=sm) + + await manager.delete("sbx-1") + + sm.send.assert_not_called() diff --git a/tests/unit/sandbox/archive/test_image_storage_interface.py b/tests/unit/sandbox/archive/test_image_storage_interface.py new file mode 100644 index 0000000000..4786678ace --- /dev/null +++ b/tests/unit/sandbox/archive/test_image_storage_interface.py @@ -0,0 +1,115 @@ +import pytest + +from rock.sandbox.archive.abstract import AbstractImageStorage + + +class TestAbstractImageStorageCannotInstantiate: + def test_cannot_instantiate(self): + with pytest.raises(TypeError): + AbstractImageStorage() + + +class InMemoryFakeImageStorage(AbstractImageStorage): + """In-memory fake for testing callers of AbstractImageStorage.""" + + def __init__(self, registry_url: str = "localhost:5000"): + self._registry_url = registry_url + self._store: dict[str, bytes] = {} + + @property + def registry_url(self) -> str: + return self._registry_url + + async def push_from_local(self, local_image_tag: str, remote_image_ref: str) -> None: + self._store[remote_image_ref] = b"fake-manifest" + + async def pull_to_local(self, remote_image_ref: str) -> None: + if remote_image_ref not in self._store: + raise RuntimeError(f"image not found: {remote_image_ref}") + + async def delete(self, image_ref: str) -> bool: + if image_ref in self._store: + del self._store[image_ref] + return True + return False + + async def exists(self, image_ref: str) -> bool: + return image_ref in self._store + + +class TestInMemoryFakeImageStorage: + @pytest.fixture + def storage(self): + return InMemoryFakeImageStorage() + + async def test_push_then_exists(self, storage): + await storage.push_from_local("local:tag", "localhost:5000/foo:bar") + assert await storage.exists("localhost:5000/foo:bar") is True + + async def test_pull_nonexistent_raises(self, storage): + with pytest.raises(RuntimeError): + await storage.pull_to_local("localhost:5000/no:thing") + + async def 
test_delete(self, storage): + await storage.push_from_local("local:tag", "localhost:5000/foo:bar") + assert await storage.delete("localhost:5000/foo:bar") is True + assert await storage.exists("localhost:5000/foo:bar") is False + + async def test_delete_nonexistent(self, storage): + assert await storage.delete("nope") is False + + async def test_registry_url_property(self, storage): + assert storage.registry_url == "localhost:5000" + + +class TestDockerRegistryV2ImageStorageWithAuth: + """Test DockerRegistryV2ImageStorage with username/password (authenticated mode).""" + + def test_instantiates_with_auth(self): + from rock.sandbox.archive.registry_v2 import DockerRegistryV2ImageStorage + + s = DockerRegistryV2ImageStorage( + registry_url="rock-registry.cn-shanghai.cr.aliyuncs.com", + username="user", + password="pass", + ) + assert isinstance(s, AbstractImageStorage) + + def test_client_config_with_auth(self): + from rock.sandbox.archive.registry_v2 import DockerRegistryV2ImageStorage + + s = DockerRegistryV2ImageStorage( + registry_url="rock-registry.cn-shanghai.cr.aliyuncs.com", + username="user", + password="pass", + ) + cfg = s.client_config + assert cfg["registry_url"] == "rock-registry.cn-shanghai.cr.aliyuncs.com" + assert cfg["username"] == "user" + assert cfg["password"] == "pass" + + def test_registry_url_property(self): + from rock.sandbox.archive.registry_v2 import DockerRegistryV2ImageStorage + + s = DockerRegistryV2ImageStorage( + registry_url="rock-registry.cn-shanghai.cr.aliyuncs.com", + username="user", + password="pass", + ) + assert s.registry_url == "rock-registry.cn-shanghai.cr.aliyuncs.com" + + async def test_delete_returns_false_when_challenge_fails(self): + from unittest.mock import AsyncMock, MagicMock, patch + + from rock.sandbox.archive.registry_v2 import DockerRegistryV2ImageStorage + + s = DockerRegistryV2ImageStorage(registry_url="r.example.com", username="u", password="p") + mock_response = MagicMock() + mock_response.headers = {} + 
mock_response.status_code = 401 + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + with patch("httpx.AsyncClient", return_value=mock_client): + assert await s.delete("r.example.com/ns/repo:tag") is False diff --git a/tests/unit/sandbox/archive/test_sandbox_actor_archive.py b/tests/unit/sandbox/archive/test_sandbox_actor_archive.py new file mode 100644 index 0000000000..59e5b67e77 --- /dev/null +++ b/tests/unit/sandbox/archive/test_sandbox_actor_archive.py @@ -0,0 +1,249 @@ +"""Unit tests for SandboxActor.archive() method (mocked, no Ray).""" + +import os +import tempfile +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +@pytest.fixture +def actor(): + """Create a SandboxActor instance without Ray, using the underlying class.""" + from rock.sandbox.sandbox_actor import SandboxActor + + ActorClass = SandboxActor.__ray_actor_class__ + + config = MagicMock() + config.container_name = "sbx-test-123" + deployment = MagicMock() + + instance = ActorClass.__new__(ActorClass) + instance._config = config + instance._deployment = deployment + instance._run_shell_command = AsyncMock() + return instance + + +@pytest.fixture +def dir_storage_config(): + return { + "endpoint": "https://oss-cn-hangzhou.aliyuncs.com", + "bucket": "test-bucket", + "access_key_id": "test-ak", + "access_key_secret": "test-sk", + "region": "cn-hangzhou", + } + + +@pytest.fixture +def image_storage_config(): + return { + "registry_url": "rock-registry.cn-shanghai.cr.aliyuncs.com", + "username": "test-user", + "password": "test-pass", + } + + +class TestSandboxActorArchive: + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.upload_dir", new_callable=AsyncMock) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.push_from_local", new_callable=AsyncMock) + async def test_happy_path( + self, mock_push, 
mock_upload, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import rock.env_vars as _env_vars + + with tempfile.TemporaryDirectory() as tmpdir: + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", tmpdir) + log_dir = os.path.join(tmpdir, "sbx-test-123") + os.makedirs(log_dir) + + await actor.archive(dir_storage_config, image_storage_config) + + actor._run_shell_command.assert_any_call( + "docker", "commit", "sbx-test-123", "archive-staging-sbx-test-123:latest" + ) + mock_push.assert_called_once() + mock_upload.assert_called_once() + call_args = mock_upload.call_args + assert "sbx-test-123" in call_args[0][0] + assert "sbx-test-123" in call_args[0][1] + + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.upload_dir", new_callable=AsyncMock) + @patch( + "rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.push_from_local", + new_callable=AsyncMock, + side_effect=RuntimeError("push failed"), + ) + async def test_push_failure_cleans_local_tag( + self, mock_push, mock_upload, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import rock.env_vars as _env_vars + + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", "/tmp/logs") + + with pytest.raises(RuntimeError, match="push failed"): + await actor.archive(dir_storage_config, image_storage_config) + + actor._run_shell_command.assert_any_call("docker", "rmi", "archive-staging-sbx-test-123:latest", check=False) + mock_upload.assert_not_called() + + @patch( + "rock.sandbox.archive.oss_storage.OssDirStorage.upload_dir", + new_callable=AsyncMock, + side_effect=RuntimeError("upload failed"), + ) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.push_from_local", new_callable=AsyncMock) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.delete", new_callable=AsyncMock) + async def test_upload_failure_rolls_back_image( + self, mock_delete_img, mock_push, mock_upload, actor, dir_storage_config, image_storage_config, monkeypatch + ): 
+ import rock.env_vars as _env_vars + + with tempfile.TemporaryDirectory() as tmpdir: + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", tmpdir) + log_dir = os.path.join(tmpdir, "sbx-test-123") + os.makedirs(log_dir) + + with pytest.raises(RuntimeError, match="upload failed"): + await actor.archive(dir_storage_config, image_storage_config) + + mock_push.assert_called_once() + mock_delete_img.assert_called_once() + ref_arg = mock_delete_img.call_args[0][0] + assert "sbx-test-123" in ref_arg + + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.push_from_local", new_callable=AsyncMock) + async def test_no_logging_path_skips_log_archive( + self, mock_push, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import rock.env_vars as _env_vars + + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", None) + + await actor.archive(dir_storage_config, image_storage_config) + + mock_push.assert_called_once() + + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.upload_dir", new_callable=AsyncMock) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.push_from_local", new_callable=AsyncMock) + async def test_missing_log_dir_skips_upload( + self, mock_push, mock_upload, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import rock.env_vars as _env_vars + + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", "/nonexistent") + + await actor.archive(dir_storage_config, image_storage_config) + + mock_push.assert_called_once() + mock_upload.assert_not_called() + + +class TestArchiveSizeLimits: + """Tests for archive size limit enforcement.""" + + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.push_from_local", new_callable=AsyncMock) + async def test_image_exceeds_limit_raises( + self, mock_push, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import subprocess + + import rock.env_vars as _env_vars + + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", None) + 
+ inspect_result = subprocess.CompletedProcess(args=(), returncode=0, stdout=b"20000000000", stderr=b"") + actor._run_shell_command = AsyncMock(return_value=inspect_result) + + limits = {"max_image_push_size": "16g", "max_dir_upload_size": "16g"} + + with pytest.raises(RuntimeError, match="exceeds limit"): + await actor.archive(dir_storage_config, image_storage_config, limits) + + mock_push.assert_not_called() + + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.upload_dir", new_callable=AsyncMock) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.push_from_local", new_callable=AsyncMock) + async def test_image_within_limit_proceeds( + self, mock_push, mock_upload, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import subprocess + + import rock.env_vars as _env_vars + + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", None) + + inspect_result = subprocess.CompletedProcess(args=(), returncode=0, stdout=b"1000000000", stderr=b"") + actor._run_shell_command = AsyncMock(return_value=inspect_result) + + limits = {"max_image_push_size": "16g", "max_dir_upload_size": "16g"} + await actor.archive(dir_storage_config, image_storage_config, limits) + + mock_push.assert_called_once() + + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.upload_dir", new_callable=AsyncMock) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.push_from_local", new_callable=AsyncMock) + async def test_dir_exceeds_limit_raises( + self, mock_push, mock_upload, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import subprocess + + import rock.env_vars as _env_vars + + with tempfile.TemporaryDirectory() as tmpdir: + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", tmpdir) + log_dir = os.path.join(tmpdir, "sbx-test-123") + os.makedirs(log_dir) + + image_result = subprocess.CompletedProcess(args=(), returncode=0, stdout=b"1000000000", stderr=b"") + du_result = subprocess.CompletedProcess(args=(), 
returncode=0, stdout=b"20000000000\t/tmp/logs", stderr=b"") + + async def side_effect(*args, **kwargs): + if args[0] == "du": + return du_result + return image_result + + actor._run_shell_command = AsyncMock(side_effect=side_effect) + + limits = {"max_image_push_size": "16g", "max_dir_upload_size": "16g"} + with pytest.raises(RuntimeError, match="exceeds limit"): + await actor.archive(dir_storage_config, image_storage_config, limits) + + mock_push.assert_called_once() + mock_upload.assert_not_called() + + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.upload_dir", new_callable=AsyncMock) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.push_from_local", new_callable=AsyncMock) + async def test_no_limits_skips_checks( + self, mock_push, mock_upload, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import rock.env_vars as _env_vars + + with tempfile.TemporaryDirectory() as tmpdir: + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", tmpdir) + log_dir = os.path.join(tmpdir, "sbx-test-123") + os.makedirs(log_dir) + + await actor.archive(dir_storage_config, image_storage_config) + + mock_push.assert_called_once() + mock_upload.assert_called_once() + + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.upload_dir", new_callable=AsyncMock) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.push_from_local", new_callable=AsyncMock) + async def test_empty_limit_string_skips_check( + self, mock_push, mock_upload, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import rock.env_vars as _env_vars + + with tempfile.TemporaryDirectory() as tmpdir: + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", tmpdir) + log_dir = os.path.join(tmpdir, "sbx-test-123") + os.makedirs(log_dir) + + limits = {"max_image_push_size": "", "max_dir_upload_size": ""} + await actor.archive(dir_storage_config, image_storage_config, limits) + + mock_push.assert_called_once() + mock_upload.assert_called_once() 
diff --git a/tests/unit/sandbox/archive/test_sandbox_actor_restore.py b/tests/unit/sandbox/archive/test_sandbox_actor_restore.py new file mode 100644 index 0000000000..d3ab32be24 --- /dev/null +++ b/tests/unit/sandbox/archive/test_sandbox_actor_restore.py @@ -0,0 +1,160 @@ +"""Unit tests for SandboxActor.restore() method (mocked, no Ray).""" + +from unittest.mock import AsyncMock, MagicMock, create_autospec, patch + +import pytest + +from rock.deployments.docker import DockerDeployment + + +@pytest.fixture +def actor(): + from rock.sandbox.sandbox_actor import SandboxActor + + ActorClass = SandboxActor.__ray_actor_class__ + + config = MagicMock() + config.container_name = "sbx-test-123" + deployment = create_autospec(DockerDeployment, instance=True) + deployment.restart = AsyncMock() + deployment.restart_from_image = AsyncMock() + + instance = ActorClass.__new__(ActorClass) + instance._config = config + instance._deployment = deployment + instance._run_shell_command = AsyncMock() + instance._clean_container_background = MagicMock() + instance._setup_monitor = AsyncMock() + return instance + + +@pytest.fixture +def dir_storage_config(): + return { + "endpoint": "https://oss-cn-hangzhou.aliyuncs.com", + "bucket": "test-bucket", + "access_key_id": "test-ak", + "access_key_secret": "test-sk", + "region": "cn-hangzhou", + } + + +@pytest.fixture +def image_storage_config(): + return { + "registry_url": "rock-registry.cn-shanghai.cr.aliyuncs.com", + "username": "test-user", + "password": "test-pass", + } + + +class TestSandboxActorRestore: + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.exists", new_callable=AsyncMock, return_value=True) + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.download_to_dir", new_callable=AsyncMock) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.pull_to_local", new_callable=AsyncMock) + async def test_happy_path( + self, mock_pull, mock_download, mock_exists, actor, dir_storage_config, 
image_storage_config, monkeypatch + ): + import rock.env_vars as _env_vars + + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", "/tmp/logs") + + await actor.restore_and_start(dir_storage_config, image_storage_config) + + mock_pull.assert_called_once() + ref_arg = mock_pull.call_args[0][0] + assert "sbx-test-123" in ref_arg + + mock_download.assert_called_once() + key_arg = mock_download.call_args[0][0] + dir_arg = mock_download.call_args[0][1] + assert "sbx-test-123" in key_arg + assert dir_arg == "/tmp/logs/sbx-test-123" + + actor._deployment.restart_from_image.assert_called_once_with(ref_arg) + actor._deployment.restart.assert_not_called() + + @patch( + "rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.pull_to_local", + new_callable=AsyncMock, + side_effect=RuntimeError("pull failed"), + ) + async def test_pull_failure_no_download( + self, mock_pull, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import rock.env_vars as _env_vars + + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", "/tmp/logs") + + with pytest.raises(RuntimeError, match="pull failed"): + await actor.restore_and_start(dir_storage_config, image_storage_config) + + actor._run_shell_command.assert_not_called() + + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.exists", new_callable=AsyncMock, return_value=True) + @patch( + "rock.sandbox.archive.oss_storage.OssDirStorage.download_to_dir", + new_callable=AsyncMock, + side_effect=RuntimeError("download failed"), + ) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.pull_to_local", new_callable=AsyncMock) + async def test_download_failure_warns_and_continues( + self, mock_pull, mock_download, mock_exists, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import rock.env_vars as _env_vars + + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", "/tmp/logs") + + await actor.restore_and_start(dir_storage_config, image_storage_config) + + mock_pull.assert_called_once() + 
actor._deployment.restart_from_image.assert_called_once() + + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.exists", new_callable=AsyncMock, return_value=True) + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.download_to_dir", new_callable=AsyncMock) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.pull_to_local", new_callable=AsyncMock) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.delete", new_callable=AsyncMock) + async def test_restore_does_not_delete_remote( + self, + mock_delete, + mock_pull, + mock_download, + mock_exists, + actor, + dir_storage_config, + image_storage_config, + monkeypatch, + ): + import rock.env_vars as _env_vars + + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", "/tmp/logs") + + await actor.restore_and_start(dir_storage_config, image_storage_config) + mock_delete.assert_not_called() + + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.pull_to_local", new_callable=AsyncMock) + async def test_no_logging_path_skips_log_restore( + self, mock_pull, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import rock.env_vars as _env_vars + + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", None) + + await actor.restore_and_start(dir_storage_config, image_storage_config) + + mock_pull.assert_called_once() + actor._deployment.restart_from_image.assert_called_once() + + @patch("rock.sandbox.archive.oss_storage.OssDirStorage.exists", new_callable=AsyncMock, return_value=False) + @patch("rock.sandbox.archive.registry_v2.DockerRegistryV2ImageStorage.pull_to_local", new_callable=AsyncMock) + async def test_missing_key_warns_and_continues( + self, mock_pull, mock_exists, actor, dir_storage_config, image_storage_config, monkeypatch + ): + import rock.env_vars as _env_vars + + monkeypatch.setattr(_env_vars, "ROCK_LOGGING_PATH", "/tmp/logs") + + await actor.restore_and_start(dir_storage_config, image_storage_config) + + mock_pull.assert_called_once() 
+ actor._deployment.restart_from_image.assert_called_once() diff --git a/tests/unit/sandbox/archive/test_state_machine_archive.py b/tests/unit/sandbox/archive/test_state_machine_archive.py new file mode 100644 index 0000000000..3b3a8ba899 --- /dev/null +++ b/tests/unit/sandbox/archive/test_state_machine_archive.py @@ -0,0 +1,208 @@ +"""Tests for archive-related state machine transitions.""" + +from unittest.mock import AsyncMock + +import pytest +from statemachine.exceptions import TransitionNotAllowed + +from rock.actions.sandbox.response import State +from rock.sandbox.sandbox_statemachine import SandboxStateMachine + + +@pytest.fixture +def meta_store(): + m = AsyncMock() + m.update = AsyncMock() + return m + + +class TestArchiveTransitions: + async def test_stopped_to_archiving(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.STOPPED, {}) + await sm.send("archive", sandbox_id="sbx-1", meta_store=meta_store) + assert sm.current_state.value == State.ARCHIVING + + async def test_archiving_to_archived(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.ARCHIVING, {"archive_time": "t1"}) + await sm.send("archive_done", sandbox_id="sbx-1", meta_store=meta_store) + assert sm.current_state.value == State.ARCHIVED + + async def test_archiving_to_stopped_on_failure(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.ARCHIVING, {"archive_time": "t1"}) + await sm.send("archive_failed", sandbox_id="sbx-1", meta_store=meta_store, reason="timeout") + assert sm.current_state.value == State.STOPPED + + async def test_archived_to_pending_on_restore(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.ARCHIVED, {"archive_time": "t1"}) + await sm.send("restore", sandbox_id="sbx-1", meta_store=meta_store) + assert sm.current_state.value == State.PENDING + + async def test_pending_to_running_on_alive_after_restore(self, meta_store): + sm = await 
SandboxStateMachine.from_state_value(State.PENDING, {"state_enter_time": "t1"}) + await sm.send("alive", sandbox_id="sbx-1", meta_store=meta_store, sandbox_info={"host_ip": "10.0.0.1"}) + assert sm.current_state.value == State.RUNNING + + async def test_pending_to_archived_on_restore_failed(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.PENDING, {"archive_time": "t1"}) + await sm.send("restore_failed", sandbox_id="sbx-1", meta_store=meta_store, reason="timeout") + assert sm.current_state.value == State.ARCHIVED + + async def test_archived_to_deleted(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.ARCHIVED, {"archive_time": "t1"}) + await sm.send( + "delete", + sandbox_id="sbx-1", + operator=AsyncMock(), + meta_store=meta_store, + reason=AsyncMock(), + ) + assert sm.current_state.value == State.DELETED + + +class TestArchiveTransitionsRejected: + async def test_running_cannot_archive(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.RUNNING, {}) + with pytest.raises(TransitionNotAllowed): + await sm.send("archive", sandbox_id="sbx-1", meta_store=meta_store) + + async def test_archiving_cannot_archive(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.ARCHIVING, {}) + with pytest.raises(TransitionNotAllowed): + await sm.send("archive", sandbox_id="sbx-1", meta_store=meta_store) + + async def test_archived_cannot_restart_directly(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.ARCHIVED, {}) + with pytest.raises(TransitionNotAllowed): + await sm.send("restart", sandbox_id="sbx-1", operator=AsyncMock(), meta_store=meta_store) + + async def test_stopped_cannot_restore(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.STOPPED, {}) + with pytest.raises(TransitionNotAllowed): + await sm.send("restore", sandbox_id="sbx-1", meta_store=meta_store) + + async def test_archiving_cannot_stop(self, meta_store): + sm = await 
SandboxStateMachine.from_state_value(State.ARCHIVING, {}) + with pytest.raises(TransitionNotAllowed): + await sm.send("stop", sandbox_id="sbx-1", operator=AsyncMock(), meta_store=meta_store) + + async def test_archiving_cannot_delete(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.ARCHIVING, {}) + with pytest.raises(TransitionNotAllowed): + await sm.send( + "delete", + sandbox_id="sbx-1", + operator=AsyncMock(), + meta_store=meta_store, + reason=AsyncMock(), + ) + + +class TestArchiveCleanupInCallbacks: + async def test_on_delete_cleans_archive_artifacts(self, meta_store): + dir_storage = AsyncMock() + image_storage = AsyncMock() + image_storage.registry_url = "localhost:5000" + operator = AsyncMock() + sm = await SandboxStateMachine.from_state_value( + State.ARCHIVED, + { + "sandbox_id": "sbx-1", + "archive_time": "2026-01-01T000000Z", + "spec": {"container_name": "sbx-1", "image": "img:latest"}, + }, + ) + await sm.send( + "delete", + sandbox_id="sbx-1", + operator=operator, + meta_store=meta_store, + dir_storage=dir_storage, + image_storage=image_storage, + ) + assert sm.current_state.value == State.DELETED + dir_storage.delete.assert_called_once() + assert "sbx-1" in dir_storage.delete.call_args[0][0] + image_storage.delete.assert_called_once() + assert "sbx-1" in image_storage.delete.call_args[0][0] + + async def test_on_delete_no_archive_time_skips_cleanup(self, meta_store): + dir_storage = AsyncMock() + image_storage = AsyncMock() + image_storage.registry_url = "localhost:5000" + operator = AsyncMock() + sm = await SandboxStateMachine.from_state_value( + State.STOPPED, + {"sandbox_id": "sbx-1", "spec": {"container_name": "sbx-1", "image": "img:latest"}}, + ) + await sm.send( + "delete", + sandbox_id="sbx-1", + operator=operator, + meta_store=meta_store, + dir_storage=dir_storage, + image_storage=image_storage, + ) + assert sm.current_state.value == State.DELETED + dir_storage.delete.assert_not_called() + 
image_storage.delete.assert_not_called() + + async def test_on_delete_without_storage_skips_cleanup(self, meta_store): + operator = AsyncMock() + sm = await SandboxStateMachine.from_state_value( + State.ARCHIVED, + { + "sandbox_id": "sbx-1", + "archive_time": "2026-01-01T000000Z", + "spec": {"container_name": "sbx-1", "image": "img:latest"}, + }, + ) + await sm.send( + "delete", + sandbox_id="sbx-1", + operator=operator, + meta_store=meta_store, + ) + assert sm.current_state.value == State.DELETED + + +class TestArchiveHookSideEffects: + async def test_on_archive_sets_fields(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.STOPPED, {}) + await sm.send("archive", sandbox_id="sbx-1", meta_store=meta_store) + info = sm.sandbox_info + assert "archive_time" not in info + assert info["state_enter_time"] is not None + assert info["state"] == State.ARCHIVING + + async def test_on_archive_done_sets_archive_time(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.ARCHIVING, {"state_enter_time": "old"}) + await sm.send("archive_done", sandbox_id="sbx-1", meta_store=meta_store) + info = sm.sandbox_info + assert info["state"] == State.ARCHIVED + assert info.get("archive_time") is not None + assert "state_enter_time" not in info + + async def test_on_archive_failed_clears_all(self, meta_store): + sm = await SandboxStateMachine.from_state_value( + State.ARCHIVING, {"archive_time": "t1", "state_enter_time": "old"} + ) + await sm.send("archive_failed", sandbox_id="sbx-1", meta_store=meta_store, reason="timeout") + info = sm.sandbox_info + assert info["state"] == State.STOPPED + assert "archive_time" not in info + assert "state_enter_time" not in info + + async def test_on_restore_sets_pending_and_keeps_archive_time(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.ARCHIVED, {"archive_time": "t1"}) + await sm.send("restore", sandbox_id="sbx-1", meta_store=meta_store) + info = sm.sandbox_info + assert 
info["state"] == State.PENDING + assert "archive_time" in info + assert "state_enter_time" in info + + async def test_on_restore_failed_rolls_back_to_archived(self, meta_store): + sm = await SandboxStateMachine.from_state_value(State.PENDING, {"archive_time": "t1", "state_enter_time": "t2"}) + await sm.send("restore_failed", sandbox_id="sbx-1", meta_store=meta_store, reason="timeout") + info = sm.sandbox_info + assert info["state"] == State.ARCHIVED + assert "state_enter_time" not in info diff --git a/tests/unit/sandbox/archive/test_storage_interface.py b/tests/unit/sandbox/archive/test_storage_interface.py new file mode 100644 index 0000000000..c967bb2e84 --- /dev/null +++ b/tests/unit/sandbox/archive/test_storage_interface.py @@ -0,0 +1,84 @@ +import pytest + +from rock.sandbox.archive.abstract import AbstractDirStorage + + +class TestAbstractDirStorageCannotInstantiate: + def test_cannot_instantiate(self): + with pytest.raises(TypeError): + AbstractDirStorage() + + +class InMemoryFakeDirStorage(AbstractDirStorage): + """In-memory fake for testing callers of AbstractDirStorage.""" + + def __init__(self): + self._store: dict[str, bytes] = {} + + async def upload_dir(self, local_dir: str, key: str) -> None: + self._store[key] = b"fake-tar-content" + + async def download_to_dir(self, key: str, local_dir: str) -> None: + if key not in self._store: + raise FileNotFoundError(key) + + async def delete(self, key: str) -> bool: + if key in self._store: + del self._store[key] + return True + return False + + async def exists(self, key: str) -> bool: + return key in self._store + + +class TestInMemoryFakeDirStorage: + @pytest.fixture + def storage(self): + return InMemoryFakeDirStorage() + + async def test_upload_then_exists(self, storage): + await storage.upload_dir("/tmp/fake", "mykey") + assert await storage.exists("mykey") is True + + async def test_delete(self, storage): + await storage.upload_dir("/tmp/fake", "mykey") + assert await storage.delete("mykey") is True + 
assert await storage.exists("mykey") is False + + async def test_delete_nonexistent(self, storage): + assert await storage.delete("nope") is False + + async def test_download_nonexistent_raises(self, storage): + with pytest.raises(FileNotFoundError): + await storage.download_to_dir("nope", "/tmp/out") + + +class TestOssDirStorageInterface: + def test_instantiates(self): + from rock.sandbox.archive.oss_storage import OssDirStorage + + s = OssDirStorage( + endpoint="https://oss-cn-hangzhou.aliyuncs.com", + bucket="test-bucket", + access_key_id="ak", + access_key_secret="sk", + region="cn-hangzhou", + ) + assert isinstance(s, AbstractDirStorage) + + def test_client_config(self): + from rock.sandbox.archive.oss_storage import OssDirStorage + + s = OssDirStorage( + endpoint="https://oss-cn-hangzhou.aliyuncs.com", + bucket="test-bucket", + access_key_id="ak", + access_key_secret="sk", + region="cn-hangzhou", + ) + cfg = s.client_config + assert cfg["endpoint"] == "https://oss-cn-hangzhou.aliyuncs.com" + assert cfg["bucket"] == "test-bucket" + assert cfg["access_key_id"] == "ak" + assert cfg["region"] == "cn-hangzhou" diff --git a/tests/unit/sandbox/operator/test_k8s_operator.py b/tests/unit/sandbox/operator/test_k8s_operator.py index d0dcb784fa..ec1dd1ba7e 100644 --- a/tests/unit/sandbox/operator/test_k8s_operator.py +++ b/tests/unit/sandbox/operator/test_k8s_operator.py @@ -244,6 +244,7 @@ async def get_current_statemachine(sandbox_id: str): return await SandboxStateMachine.from_state_value(info.get("state"), sandbox_info=info) m._get_current_statemachine = AsyncMock(side_effect=get_current_statemachine) + m._try_advance_pending = SandboxManager._try_advance_pending.__get__(m, SandboxManager) m.get_status = SandboxManager.get_status.__get__(m, SandboxManager) m._refresh_timeout = AsyncMock() return m diff --git a/tests/unit/sandbox/test_sandbox_transitions.py b/tests/unit/sandbox/test_sandbox_transitions.py index ce2756d412..83d4a81da8 100644 --- 
a/tests/unit/sandbox/test_sandbox_transitions.py +++ b/tests/unit/sandbox/test_sandbox_transitions.py @@ -61,6 +61,7 @@ async def get_current_statemachine(sandbox_id: str) -> SandboxStateMachine | Non m._get_current_statemachine = AsyncMock(side_effect=get_current_statemachine) m.stop = SandboxManager.stop.__get__(m, SandboxManager) + m._try_advance_pending = SandboxManager._try_advance_pending.__get__(m, SandboxManager) m.get_status = SandboxManager.get_status.__get__(m, SandboxManager) m._refresh_timeout = AsyncMock() return m