diff --git a/rock/actions/sandbox/sandbox_info.py b/rock/actions/sandbox/sandbox_info.py index 33ee9ab1d5..22ed810195 100644 --- a/rock/actions/sandbox/sandbox_info.py +++ b/rock/actions/sandbox/sandbox_info.py @@ -27,6 +27,7 @@ class SandboxInfo(TypedDict, total=False): stop_time: str delete_time: str extended_params: dict[str, str] + operator_name: str _SANDBOX_INFO_KEYS = frozenset(SandboxInfo.__annotations__.keys()) diff --git a/rock/admin/main.py b/rock/admin/main.py index a3c7537f40..019ec44957 100644 --- a/rock/admin/main.py +++ b/rock/admin/main.py @@ -42,6 +42,8 @@ from rock.logger import init_logger from rock.sandbox.gem_manager import GemManager from rock.sandbox.operator.factory import OperatorContext, OperatorFactory +from rock.sandbox.operator.registry import OperatorRegistry +from rock.sandbox.operator.routing import Router from rock.sandbox.sandbox_meta_store import SandboxMetaStore from rock.sandbox.service.sandbox_proxy_service import SandboxProxyService from rock.sandbox.service.warmup_service import WarmupService @@ -145,19 +147,44 @@ async def lifespan(app: FastAPI): # init sandbox service if args.role == "admin": - # init ray service - ray_service = RayService(rock_config.ray) - ray_service.init() - - # create operator using factory with context pattern + # Configuration-driven loading: only initialize what the YAML + # actually declared via top-level operator keys. + present = rock_config.present_operator_keys or {rock_config.runtime.operator_type} + ray_service: RayService | None = None + if "ray" in present: + ray_service = RayService(rock_config.ray) + ray_service.init() + + # build OperatorContext once with all available dependencies; the + # Factory.build call only consumes the ones each operator needs. operator_context = OperatorContext( runtime_config=rock_config.runtime, ray_service=ray_service, redis_provider=redis_provider, nacos_provider=rock_config.nacos_provider, - k8s_config=rock_config.k8s, + k8s_config=rock_config.k8s if "k8s" in present else None, + remote_config=rock_config.remote if "remote" in present else None, ) - operator = OperatorFactory.create_operator(operator_context) + + default_operator_name = rock_config.runtime.operator_type + if default_operator_name not in present: + raise ValueError( + f"runtime.operator_type={default_operator_name!r} but its config block is" + f" not present in YAML. Loaded operator keys: {sorted(present)}" + ) + operator_registry = OperatorRegistry(default_name=default_operator_name) + for name in sorted(present): + operator_registry.register(name, OperatorFactory.build(name, operator_context)) + if rock_config.runtime.operator_routing: + router = Router.from_config( + rock_config.runtime.operator_routing, + fallback_default=default_operator_name, + loaded_operators=operator_registry.loaded_names, + ) + operator_registry.set_router(router) + logger.info("operator router enabled with %d rule(s)", len(router.rules)) + else: + logger.info("no operator_routing configured; all submits go to default=%s", default_operator_name) # init service if rock_config.runtime.enable_auto_clear: @@ -166,7 +193,7 @@ async def lifespan(app: FastAPI): ray_namespace=rock_config.ray.namespace, ray_service=ray_service, enable_runtime_auto_clear=True, - operator=operator, + registry=operator_registry, meta_store=meta_store, ) else: @@ -175,7 +202,7 @@ async def lifespan(app: FastAPI): ray_namespace=rock_config.ray.namespace, ray_service=ray_service, enable_runtime_auto_clear=False, - operator=operator, + registry=operator_registry, meta_store=meta_store, ) set_sandbox_manager(sandbox_manager) diff --git a/rock/config.py b/rock/config.py index e2b2b7d8f7..608cd3cba8 100644 --- a/rock/config.py +++ b/rock/config.py @@ -218,6 +218,30 @@ def __post_init__(self): self.ports[key] = value +@dataclass +class RemoteConfig: + """Configuration for the Remote operator.""" + + api_endpoint: str = "" + + api_key: str = "" + + rocklet_port: int = 8000 + + header_sandbox_id: str = "X-Sandbox-Id" + + header_sandbox_port: str = "X-Sandbox-Port" + + def resolved_api_key(self) -> str: + """Return the effective api_key, env var ROCK_REMOTE_API_KEY takes precedence.""" + import os + + value = os.environ.get("ROCK_REMOTE_API_KEY", "") + if value: + return value + return self.api_key + + @dataclass class K8sConfig: """Kubernetes configuration for K8s operator.""" @@ -257,6 +281,16 @@ class RuntimeConfig: python_env_path: str = field(default_factory=lambda: env_vars.ROCK_PYTHON_ENV_PATH) envhub_db_url: str = field(default_factory=lambda: env_vars.ROCK_ENVHUB_DB_URL) operator_type: str = "ray" + operator_routing: dict | None = None + """Routing rules consumed by Router.from_config. Kept as a raw dict so the + routing layer owns its schema. Example:: + + operator_routing: + default: ray # optional; falls back to operator_type + rules: + - match: {image_prefix: "reg.example.com/remote/"} + target: remote + """ standard_spec: StandardSpec = field(default_factory=StandardSpec) max_allowed_spec: StandardSpec = field(default_factory=lambda: StandardSpec(cpus=16, memory="64g")) use_standard_spec_only: bool = False @@ -324,10 +358,16 @@ def _resolve_k8s_template_includes(k8s_dict: dict, base_dir: Path) -> None: k8s_dict["templates"] = merged +# Top-level YAML keys whose presence triggers loading of the corresponding +# operator. Single source of truth for the configuration-driven loader. +OPERATOR_CONFIG_KEYS: tuple[str, ...] = ("ray", "k8s", "remote") + + @dataclass class RockConfig: ray: RayConfig = field(default_factory=RayConfig) k8s: K8sConfig = field(default_factory=K8sConfig) + remote: RemoteConfig | None = None warmup: WarmupConfig = field(default_factory=WarmupConfig) nacos: NacosConfig = field(default_factory=NacosConfig) redis: RedisConfig = field(default_factory=RedisConfig) @@ -338,6 +378,10 @@ class RockConfig: scheduler: SchedulerConfig = field(default_factory=SchedulerConfig) database: DatabaseConfig = field(default_factory=DatabaseConfig) nacos_provider: NacosConfigProvider | None = None + # Operator config keys actually present in the loaded YAML. Populated by + # from_env after merging _base. Drives configuration-driven operator + # loading: an operator is registered iff its key is in this set. + present_operator_keys: set[str] = field(default_factory=set, init=False) @classmethod def from_env(cls, config_path: str | None = None): @@ -374,6 +418,8 @@ def from_env(cls, config_path: str | None = None): if "k8s" in config: _resolve_k8s_template_includes(config["k8s"], config_file.parent) kwargs["k8s"] = K8sConfig(**config["k8s"]) + if "remote" in config: + kwargs["remote"] = RemoteConfig(**config["remote"]) if "warmup" in config: kwargs["warmup"] = WarmupConfig(**config["warmup"]) if "nacos" in config: @@ -393,7 +439,11 @@ def from_env(cls, config_path: str | None = None): if "database" in config: kwargs["database"] = DatabaseConfig(**config["database"]) - return cls(**kwargs) + instance = cls(**kwargs) + # Record which operator config blocks were actually present so the + # Registry can load exactly what the YAML declares. + instance.present_operator_keys = {k for k in OPERATOR_CONFIG_KEYS if k in config} + return instance # ============================================================================ # Merging Rules: diff --git a/rock/sandbox/gem_manager.py b/rock/sandbox/gem_manager.py index 0858852df3..1e8e5af66e 100644 --- a/rock/sandbox/gem_manager.py +++ b/rock/sandbox/gem_manager.py @@ -16,6 +16,7 @@ from rock.admin.proto.response import SandboxStartResponse, SandboxStatusResponse from rock.config import RockConfig from rock.deployments.config import DockerDeploymentConfig +from rock.sandbox.operator.registry import OperatorRegistry from rock.sandbox.sandbox_actor import SandboxActor from rock.sandbox.sandbox_manager import SandboxManager from rock.sandbox.sandbox_meta_store import SandboxMetaStore @@ -25,19 +26,19 @@ class GemManager(SandboxManager): def __init__( self, rock_config: RockConfig, - meta_store: SandboxMetaStore | None = None, + meta_store: SandboxMetaStore, + registry: OperatorRegistry, ray_namespace: str = env_vars.ROCK_RAY_NAMESPACE, ray_service: RayService | None = None, enable_runtime_auto_clear: bool = False, - operator=None, ): super().__init__( rock_config, meta_store=meta_store, + registry=registry, ray_namespace=ray_namespace, ray_service=ray_service, enable_runtime_auto_clear=enable_runtime_auto_clear, - operator=operator, ) async def env_make(self, env_id: str) -> EnvMakeResponse: diff --git a/rock/sandbox/operator/factory.py b/rock/sandbox/operator/factory.py index fd3d31aaa7..cc1eb76d58 100644 --- a/rock/sandbox/operator/factory.py +++ b/rock/sandbox/operator/factory.py @@ -4,7 +4,7 @@ from typing import Any from rock.admin.core.ray_service import RayService -from rock.config import K8sConfig, RuntimeConfig +from rock.config import K8sConfig, RemoteConfig, RuntimeConfig from rock.logger import init_logger from rock.sandbox.operator.abstract import AbstractOperator from rock.sandbox.operator.k8s.operator import K8sOperator @@ -30,6 +30,8 @@ class OperatorContext: # K8s operator dependencies k8s_config: K8sConfig | None = None nacos_provider: NacosConfigProvider | None = None + # Remote operator dependencies + remote_config: RemoteConfig | None = None # Future operator dependencies can be added here without breaking existing code extra_params: dict[str, Any] = field(default_factory=dict) @@ -38,25 +40,21 @@ class OperatorFactory: """Factory class for creating operator instances. Uses the Context Object pattern to avoid parameter explosion as new - operator types are added. + operator types are added. ``build`` constructs a single operator by name, + while ``create_operator`` is kept for backward compatibility (loads the + single operator selected by ``runtime_config.operator_type``). """ @staticmethod - def create_operator(context: OperatorContext) -> AbstractOperator: - """Create an operator instance based on the runtime configuration. - - Args: - context: OperatorContext containing all necessary dependencies + def build(name: str, context: OperatorContext) -> AbstractOperator: + """Construct one operator instance by config-key name. - Returns: - AbstractOperator: The created operator instance - - Raises: - ValueError: If operator_type is not supported or required dependencies are missing + ``name`` matches the top-level YAML key that triggered loading + (one of ``OPERATOR_CONFIG_KEYS``: ``ray``/``k8s``/``remote``). """ - operator_type = context.runtime_config.operator_type.lower() + key = name.lower() - if operator_type == "ray": + if key == "ray": if context.ray_service is None: raise ValueError("RayService is required for RayOperator") logger.info("Creating RayOperator") @@ -66,7 +64,7 @@ def create_operator(context: OperatorContext) -> AbstractOperator: if context.nacos_provider is not None: ray_operator.set_nacos_provider(context.nacos_provider) return ray_operator - elif operator_type == "k8s": + elif key == "k8s": if context.k8s_config is None: raise ValueError("K8sConfig is required for K8sOperator") logger.info("Creating K8sOperator") @@ -76,5 +74,21 @@ def create_operator(context: OperatorContext) -> AbstractOperator: if context.nacos_provider is not None: k8s_operator.set_nacos_provider(context.nacos_provider) return k8s_operator + elif key == "remote": + if context.remote_config is None: + raise ValueError("RemoteConfig is required for RemoteOperator") + # Lazy import to avoid pulling httpx into ray/k8s-only deployments. + from rock.sandbox.operator.remote.operator import RemoteOperator + + logger.info("Creating RemoteOperator endpoint=%s", context.remote_config.api_endpoint) + remote_operator = RemoteOperator(remote_config=context.remote_config) + if context.redis_provider is not None: + remote_operator.set_redis_provider(context.redis_provider) + return remote_operator else: - raise ValueError(f"Unsupported operator type: {operator_type}. Supported types: ray, kubernetes") + raise ValueError(f"Unsupported operator name: {name!r}. Supported: ray, k8s, remote") + + @staticmethod + def create_operator(context: OperatorContext) -> AbstractOperator: + """Backward-compatible single-operator creation by ``operator_type``.""" + return OperatorFactory.build(context.runtime_config.operator_type, context) diff --git a/rock/sandbox/operator/registry.py b/rock/sandbox/operator/registry.py new file mode 100644 index 0000000000..6c34ae1dff --- /dev/null +++ b/rock/sandbox/operator/registry.py @@ -0,0 +1,79 @@ +"""Operator registry: name → AbstractOperator instance. + +Two dispatch paths share the registry: + * Submit path: ``resolve(RouteContext)`` → router → operator + * Operate path: ``get(operator_name)`` → operator (name from sandbox meta) +""" + +from __future__ import annotations + +from rock.logger import init_logger +from rock.sandbox.operator.abstract import AbstractOperator +from rock.sandbox.operator.routing.context import RouteContext +from rock.sandbox.operator.routing.router import Router + +logger = init_logger(__name__) + + +class OperatorRegistry: + """Holds all loaded operators and dispatches by name or by route ctx. + + Built once at startup. ``register`` is intended to be called only + during initialization; runtime mutation is not supported. + """ + + def __init__(self, default_name: str) -> None: + self._default_name: str = default_name + self._operators: dict[str, AbstractOperator] = {} + self._router: Router | None = None + + # ------------------------------------------------------------------ + # Registration (startup-only) + # ------------------------------------------------------------------ + + def register(self, name: str, operator: AbstractOperator) -> None: + if name in self._operators: + raise ValueError(f"operator {name!r} already registered") + self._operators[name] = operator + logger.info("Operator registered: %s (%s)", name, type(operator).__name__) + + def set_router(self, router: Router) -> None: + self._router = router + + # ------------------------------------------------------------------ + # Lookup + # ------------------------------------------------------------------ + + @property + def default_name(self) -> str: + return self._default_name + + @property + def loaded_names(self) -> set[str]: + return set(self._operators) + + def get(self, name: str | None) -> AbstractOperator: + """Resolve operator by name; empty/None → default. + + Raises ``KeyError`` when the name was never loaded — callers + should let this surface to alert ops that a sandbox's bound + operator has been removed from config. + """ + effective = name or self._default_name + op = self._operators.get(effective) + if op is None: + raise KeyError( + f"operator {effective!r} is not loaded; " + f"loaded={sorted(self._operators)}" + ) + return op + + def resolve(self, ctx: RouteContext) -> tuple[str, AbstractOperator]: + """Submit-time entry: route by ctx, return (name, operator).""" + if self._router is None: + # No routing configured → everything goes to default. This is + # the legacy single-operator path. + return self._default_name, self.get(self._default_name) + name, routed_by = self._router.route(ctx) + logger.info("routing: operator_name=%s routed_by=%s", name, routed_by) + return name, self.get(name) diff --git a/rock/sandbox/operator/remote/__init__.py b/rock/sandbox/operator/remote/__init__.py new file mode 100644 index 0000000000..128afc0981 --- /dev/null +++ b/rock/sandbox/operator/remote/__init__.py @@ -0,0 +1,5 @@ +"""RemoteOperator: delegates sandbox lifecycle to an external sandbox API.""" + +from rock.sandbox.operator.remote.operator import RemoteOperator + +__all__ = ["RemoteOperator"] diff --git a/rock/sandbox/operator/remote/client.py b/rock/sandbox/operator/remote/client.py new file mode 100644 index 0000000000..805110c2da --- /dev/null +++ b/rock/sandbox/operator/remote/client.py @@ -0,0 +1,158 @@ +"""Async httpx client for the remote sandbox control-plane API.""" + +from __future__ import annotations + +import logging +from typing import Any + +import httpx + +from rock.config import RemoteConfig +from rock.utils import REQUEST_TIMEOUT_SECONDS + +logger = logging.getLogger(__name__) + + +class RemoteApiError(RuntimeError): + """Non-2xx response from the remote sandbox API.""" + + def __init__(self, status_code: int, message: str, body: Any = None) -> None: + super().__init__(f"remote api {status_code}: {message}") + self.status_code = status_code + self.body = body + + +class RemoteClient: + """Async client for the remote sandbox API.""" + + def __init__(self, config: RemoteConfig) -> None: + if not config.api_endpoint: + raise ValueError("RemoteConfig.api_endpoint is required") + self._config = config + self._client = httpx.AsyncClient( + base_url=config.api_endpoint.rstrip("/"), + timeout=REQUEST_TIMEOUT_SECONDS, + verify=False, + ) + + @property + def config(self) -> RemoteConfig: + return self._config + + def _headers(self) -> dict[str, str]: + api_key = self._config.resolved_api_key() + if not api_key: + raise RuntimeError( + "RemoteConfig has no api_key (neither inline nor env var resolved a value)" + ) + return { + "X-API-Key": api_key, + "Content-Type": "application/json", + "Accept": "application/json", + } + + async def close(self) -> None: + await self._client.aclose() + + async def _request( + self, + method: str, + path: str, + *, + json_body: dict[str, Any] | None = None, + allow_404: bool = False, + extra_headers: dict[str, str] | None = None, + ) -> dict[str, Any] | None: + try: + headers = self._headers() + if extra_headers: + headers.update(extra_headers) + response = await self._client.request( + method, + path, + headers=headers, + json=json_body, + ) + except httpx.HTTPError as exc: + logger.error("remote api transport error %s %s: %r", method, path, exc) + raise RemoteApiError(0, f"transport error: {exc}") from exc + + if allow_404 and response.status_code == 404: + return None + + if 200 <= response.status_code < 300: + if response.status_code == 204 or not response.content: + return {} + try: + return response.json() + except ValueError: + return {"raw": response.text} + + # Non-2xx: surface a structured error including the body for log triage. + body: Any + try: + body = response.json() + except ValueError: + body = response.text + message = "" + if isinstance(body, dict): + message = str(body.get("message") or body.get("error") or body) + else: + message = str(body) + logger.error( + "remote api %s %s -> %s: %s", + method, + path, + response.status_code, + message, + ) + raise RemoteApiError(response.status_code, message, body) + + # ------------------------------------------------------------------ + # Endpoints (per docs/openapi.yml) + # ------------------------------------------------------------------ + + async def create(self, payload: dict[str, Any]) -> dict[str, Any]: + """POST /sandboxes (async build mode)""" + body = await self._request( + "POST", "/sandboxes", json_body=payload, + extra_headers={"Async-Build": "true"}, + ) + if not body: + raise RemoteApiError(0, "create returned empty body") + return body + + async def get(self, sandbox_id: str) -> dict[str, Any] | None: + """GET /sandboxes/{id} — returns None on 404.""" + return await self._request("GET", f"/sandboxes/{sandbox_id}", allow_404=True) + + async def stop(self, sandbox_id: str) -> bool: + """POST /sandboxes/{id}/pause — returns False on 404.""" + body = await self._request("POST", f"/sandboxes/{sandbox_id}/pause", allow_404=True) + return body is not None + + async def restart(self, sandbox_id: str, timeout_seconds: int) -> dict[str, Any] | None: + """POST /sandboxes/{id}/connect — returns None on 404.""" + return await self._request( + "POST", + f"/sandboxes/{sandbox_id}/connect", + json_body={"timeout": int(timeout_seconds)}, + allow_404=True, + ) + + async def delete(self, sandbox_id: str) -> bool: + """DELETE /sandboxes/{id} — returns False on 404.""" + body = await self._request("DELETE", f"/sandboxes/{sandbox_id}", allow_404=True) + return body is not None + + async def keep_alive(self, sandbox_id: str) -> None: + """POST /sandboxes/{id}/refreshes""" + await self._request("POST", f"/sandboxes/{sandbox_id}/refreshes") + + async def set_timeout(self, sandbox_id: str, timeout_seconds: int) -> None: + """POST /sandboxes/{id}/timeout""" + await self._request( + "POST", + f"/sandboxes/{sandbox_id}/timeout", + json_body={"timeout": int(timeout_seconds)}, + ) diff --git a/rock/sandbox/operator/remote/mapping.py b/rock/sandbox/operator/remote/mapping.py new file mode 100644 index 0000000000..5430ed4942 --- /dev/null +++ b/rock/sandbox/operator/remote/mapping.py @@ -0,0 +1,158 @@ +"""Bidirectional mapping between ROCK domain types and remote sandbox API payloads.""" + +from __future__ import annotations + +from typing import Any + +from rock.actions.sandbox.response import State +from rock.actions.sandbox.sandbox_info import SandboxInfo +from rock.deployments.config import DockerDeploymentConfig +from rock.deployments.constants import Status +from rock.utils.format import parse_size_to_bytes + + +# --------------------------------------------------------------------------- +# State mapping +# --------------------------------------------------------------------------- + +# Remote SandboxState -> ROCK State +_REMOTE_TO_ROCK_STATE: dict[str, State] = { + "running": State.RUNNING, + "paused": State.STOPPED, + "building": State.PENDING, + "build_failed": State.STOPPED, +} + + +def map_remote_state(value: str | None) -> State: + """Translate a remote state string to ROCK State. Unknown values default to PENDING.""" + if not value: + return State.PENDING + return _REMOTE_TO_ROCK_STATE.get(value.lower(), State.PENDING) + + +# Synthetic phase entry for the remote lifecycle. +_REMOTE_PHASE_NAME = "remote_sandbox" +_REMOTE_TO_PHASE_STATUS: dict[str, Status] = { + "running": Status.SUCCESS, + "paused": Status.SUCCESS, + "building": Status.RUNNING, + "build_failed": Status.FAILED, +} + + +def _build_phases(raw_state: str | None) -> dict[str, dict[str, str]]: + """Synthesize a phases dict from the remote state.""" + state_str = (raw_state or "").lower() + phase_status = _REMOTE_TO_PHASE_STATUS.get(state_str, Status.WAITING) + return { + _REMOTE_PHASE_NAME: { + "status": phase_status.value, + "message": state_str or "pending", + } + } + + +# --------------------------------------------------------------------------- +# Outbound: DockerDeploymentConfig -> NewSandbox JSON body +# --------------------------------------------------------------------------- + + +def _memory_str_to_mb(memory: str | None) -> int | None: + """Convert ROCK memory string ('8g', '4096m') to MiB; return None if invalid.""" + if not memory: + return None + try: + return int(parse_size_to_bytes(memory) // (1024 * 1024)) + except ValueError: + return None + + +def to_new_sandbox_payload(config: DockerDeploymentConfig) -> dict[str, Any]: + """Build the NewSandbox JSON body for POST /sandboxes.""" + payload: dict[str, Any] = {} + + image = (config.image or "").strip() + if image: + payload["fromImage"] = image + + if config.cpus is not None: + # Remote CPUCount is int>=1; ROCK allows fractional cpus, so round up. + cpu_count = max(1, int(config.cpus + 0.999)) + payload["cpuCount"] = cpu_count + + memory_mb = _memory_str_to_mb(config.memory) + if memory_mb is not None: + payload["memoryMB"] = memory_mb + + # Auto-clear window in minutes -> seconds for the remote TTL. + if config.auto_clear_time_minutes: + payload["timeout"] = int(config.auto_clear_time_minutes) * 60 + + # autoPause=true: TTL expiry -> paused (not kill), aligns with ROCK STOPPED state + payload["autoPause"] = True + + metadata: dict[str, str] = { + "rock.container_name": config.container_name or "", + } + # extended_params -> remote metadata (string-only) + for key, value in (config.extended_params or {}).items(): + if value is None: + continue + metadata[f"rock.ext.{key}"] = str(value) + payload["metadata"] = {k: v for k, v in metadata.items() if v} + + return payload + + +# --------------------------------------------------------------------------- +# Inbound: remote JSON -> SandboxInfo +# --------------------------------------------------------------------------- + + +def from_sandbox_response( + body: dict[str, Any], + *, + config: DockerDeploymentConfig | None = None, +) -> SandboxInfo: + """Build SandboxInfo from a POST /sandboxes response.""" + info: SandboxInfo = { + "sandbox_id": body.get("sandboxID", ""), + "host_name": body.get("sandboxID", ""), + "host_ip": "", + "image": (config.image if config else "") or "", + "state": State.PENDING, + "extended_params": { + "remote.sandbox_domain": body.get("domain") or "", + "remote.traffic_access_token": body.get("trafficAccessToken") or "", + }, + "port_mapping": {}, + "phases": _build_phases(body.get("state")), + } + if config is not None: + info["cpus"] = float(config.cpus) if config.cpus is not None else 0.0 + info["memory"] = config.memory or "" + return info + + +def from_sandbox_detail(body: dict[str, Any]) -> SandboxInfo: + """Build SandboxInfo from a ``GET /sandboxes/{id}`` response (``SandboxDetail``).""" + raw_state = body.get("state") + info: SandboxInfo = { + "sandbox_id": body.get("sandboxID", ""), + "host_name": body.get("sandboxID", ""), + "host_ip": "", + "image": "", + "state": map_remote_state(raw_state), + "extended_params": { + "remote.sandbox_domain": body.get("domain") or "", + }, + "start_time": body.get("startedAt") or "", + "port_mapping": {}, + "phases": _build_phases(raw_state), + } + if body.get("cpuCount") is not None: + info["cpus"] = float(body["cpuCount"]) + if body.get("memoryMB") is not None: + info["memory"] = f"{int(body['memoryMB'])}m" + return info diff --git a/rock/sandbox/operator/remote/operator.py b/rock/sandbox/operator/remote/operator.py new file mode 100644 index 0000000000..9877554968 --- /dev/null +++ b/rock/sandbox/operator/remote/operator.py @@ -0,0 +1,161 @@ +"""RemoteOperator — delegates sandbox lifecycle to an external sandbox API.""" + +from __future__ import annotations + +from typing import Any + +from rock.actions.sandbox.response import State +from rock.actions.sandbox.sandbox_info import SandboxInfo +from rock.common.constants import StopReason +from rock.config import RemoteConfig +from rock.deployments.config import DeploymentConfig, DockerDeploymentConfig +from rock.logger import init_logger +from rock.sandbox.operator.abstract import AbstractOperator +from rock.sandbox.operator.remote.client import RemoteApiError, RemoteClient +from rock.sandbox.operator.remote.mapping import ( + from_sandbox_detail, + from_sandbox_response, + to_new_sandbox_payload, +) + +logger = init_logger(__name__) + + +class RemoteOperator(AbstractOperator): + + def __init__(self, remote_config: RemoteConfig) -> None: + self._remote_config = remote_config + self._client = RemoteClient(remote_config) + + @property + def client(self) -> RemoteClient: + return self._client + + async def close(self) -> None: + await self._client.close() + + # ------------------------------------------------------------------ + # AbstractOperator contract + # ------------------------------------------------------------------ + + async def submit(self, config: DeploymentConfig, user_info: dict | None = None) -> SandboxInfo: + if not isinstance(config, DockerDeploymentConfig): + raise TypeError( + f"RemoteOperator.submit only supports DockerDeploymentConfig, got {type(config).__name__}" + ) + local_id = config.container_name + payload = to_new_sandbox_payload(config) + logger.info( + f"[{local_id}] remote submit -> POST /sandboxes (image={payload.get('fromImage')})" + ) + body = await self._client.create(payload) + remote_id = body.get("sandboxID") or "" + if not remote_id: + raise RemoteApiError(0, f"create returned no sandboxID: {body!r}") + + sandbox_info: SandboxInfo = from_sandbox_response(body, config=config) + sandbox_info["sandbox_id"] = local_id + sandbox_info["host_name"] = remote_id + # placeholder for state machine on_restart guard check + sandbox_info["host_ip"] = self._remote_config.api_endpoint + + ext = dict(sandbox_info.get("extended_params") or {}) + ext["remote.sandbox_id"] = remote_id + sandbox_info["extended_params"] = ext + + # persist into config so the DB spec column carries remote_id for restart/delete + config.extended_params["remote.sandbox_id"] = remote_id + + if user_info: + for key in ("user_id", "experiment_id", "namespace", "rock_authorization"): + value = user_info.get(key) + if value is not None: + sandbox_info[key] = value # type: ignore[literal-required] + + logger.info( + f"[{local_id}] remote submit OK: remote_id={remote_id} domain={ext.get('remote.sandbox_domain')}" + ) + return sandbox_info + + async def restart( + self, config: DeploymentConfig, host_ip: str | None = None + ) -> SandboxInfo: + """Resume a paused sandbox via POST /connect.""" + if not isinstance(config, DockerDeploymentConfig): + raise TypeError("RemoteOperator.restart requires DockerDeploymentConfig") + local_id = config.container_name + remote_id = await self._resolve_remote_id(local_id) + if not remote_id: + remote_id = config.extended_params.get("remote.sandbox_id", "") + if not remote_id: + raise RemoteApiError( + 0, f"restart: no remote sandbox id recorded for {local_id}" + ) + timeout_seconds = int((config.auto_clear_time_minutes or 0) * 60) + body = await self._client.restart(remote_id, timeout_seconds=timeout_seconds) + if body is None: + return {"sandbox_id": local_id, "host_name": remote_id, "state": State.STOPPED} + info = from_sandbox_response(body, config=config) + info["sandbox_id"] = local_id + info["host_name"] = remote_id + info["state"] = State.RUNNING + return info + + async def get_status(self, sandbox_id: str) -> SandboxInfo | None: + remote_id = await self._resolve_remote_id(sandbox_id) + if not remote_id: + logger.debug(f"[{sandbox_id}] remote get_status: no remote id, returning None") + return None + body = await self._client.get(remote_id) + if body is None: + return None + info = from_sandbox_detail(body) + info["sandbox_id"] = sandbox_id + info["host_name"] = remote_id + info["host_ip"] = self._remote_config.api_endpoint + return info + + async def stop(self, sandbox_id: str, reason: StopReason = StopReason.MANUAL) -> bool: + """Pause the remote sandbox. Kill is reserved for delete.""" + remote_id = await self._resolve_remote_id(sandbox_id) + if not remote_id: + logger.warning(f"[{sandbox_id}] remote stop: no remote id, treating as already gone") + return True + logger.info(f"[{sandbox_id}] remote stop -> POST /sandboxes/{remote_id}/pause (reason={reason})") + return await self._client.stop(remote_id) + + async def delete(self, config: DeploymentConfig, host_ip: str | None = None) -> bool: + if not isinstance(config, DockerDeploymentConfig): + return True + local_id = config.container_name + remote_id = await self._resolve_remote_id(local_id) + if not remote_id: + remote_id = config.extended_params.get("remote.sandbox_id", "") + if not remote_id: + logger.info(f"[{local_id}] remote delete: no remote id, no-op") + return True + try: + return await self._client.delete(remote_id) + except RemoteApiError as exc: + if exc.status_code == 404: + return True + raise + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + async def _resolve_remote_id(self, sandbox_id: str) -> str: + """Look up the remote sandboxID from redis alive key.""" + info: dict[str, Any] | None = None + try: + info = await self.get_sandbox_info_from_redis(sandbox_id) + except RuntimeError: + return "" + if not info: + return "" + remote_id = info.get("host_name") or "" + if remote_id: + return remote_id + ext = info.get("extended_params") or {} + return ext.get("remote.sandbox_id") or "" diff --git a/rock/sandbox/operator/routing/__init__.py b/rock/sandbox/operator/routing/__init__.py new file mode 100644 index 0000000000..30e838c3ee --- /dev/null +++ b/rock/sandbox/operator/routing/__init__.py @@ -0,0 +1,13 @@ +"""Operator routing layer. + +Submit-time mechanism that maps a sandbox creation request to one of the +registered operators based on declarative rules. GET-class operations +(get_status / stop / restart / delete) bypass this layer and dispatch by +the ``operator_name`` field stored in sandbox meta. +""" + +from rock.sandbox.operator.routing.context import RouteContext +from rock.sandbox.operator.routing.matcher import MATCHER_REGISTRY, Matcher +from rock.sandbox.operator.routing.router import Router, RoutingRule + +__all__ = ["MATCHER_REGISTRY", "Matcher", "RouteContext", "Router", "RoutingRule"] diff --git a/rock/sandbox/operator/routing/context.py b/rock/sandbox/operator/routing/context.py new file mode 100644 index 0000000000..ed4777e5c9 --- /dev/null +++ b/rock/sandbox/operator/routing/context.py @@ -0,0 +1,63 @@ +"""Unified routing context: signal carrier for routing decisions. + +Decouples Router from request/header DTOs so it can be built from any code +path that knows the deployment intent (start path, future warmup paths, etc.). +Adding a new routing dimension only requires exposing a field here and +registering a matcher. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from rock.admin.proto.request import ClusterInfo, UserInfo + from rock.deployments.config import DeploymentConfig + + +@dataclass(frozen=True) +class RouteContext: + """Frozen snapshot of all signals that may participate in routing. + + Frozen because routing is referentially transparent: same context → + same operator. Any mutation indicates a logic bug. + """ + + image: str + image_os: str + accelerator_type: str | None + num_gpus: float | None + use_kata_runtime: bool + namespace: str + cluster: str + user_id: str + experiment_id: str + + @classmethod + def from_deployment( + cls, + config: "DeploymentConfig", + user_info: "UserInfo | None" = None, + cluster_info: "ClusterInfo | None" = None, + ) -> "RouteContext": + """Construct from the values reaching ``SandboxManager.start_async``. + + Tolerates DeploymentConfig variants that lack docker-specific fields + (e.g. ``LocalDeploymentConfig`` has no ``image``) by reading via + ``getattr``. Header defaults ("default") are propagated as-is — + matchers should treat "default" as a real value and not as ``None``. + """ + user_info = user_info or {} + cluster_info = cluster_info or {} + return cls( + image=getattr(config, "image", "") or "", + image_os=getattr(config, "image_os", "") or "", + accelerator_type=getattr(config, "accelerator_type", None), + num_gpus=getattr(config, "num_gpus", None), + use_kata_runtime=bool(getattr(config, "use_kata_runtime", False)), + namespace=user_info.get("namespace", "default") or "default", + cluster=cluster_info.get("cluster_name", "default") or "default", + user_id=user_info.get("user_id", "default") or "default", + experiment_id=user_info.get("experiment_id", "default") or "default", + ) diff --git a/rock/sandbox/operator/routing/matcher.py b/rock/sandbox/operator/routing/matcher.py new file mode 100644 index 0000000000..abd48247fb --- /dev/null +++ b/rock/sandbox/operator/routing/matcher.py @@ -0,0 +1,108 @@ +"""Matcher abstraction + registry. + +Adding a new routing dimension is a 2-step extension: + 1. Implement a ``Matcher`` subclass reading fields from ``RouteContext``. + 2. Register it under a yaml key in ``MATCHER_REGISTRY``. + +Router and config schema stay untouched. +""" + +from __future__ import annotations + +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + +from rock.sandbox.operator.routing.context import RouteContext + + +class Matcher(ABC): + """One predicate over RouteContext. + + Matchers must be stateless after construction — they are reused + across all submit requests. + """ + + yaml_key: str = "" # set by subclasses; surfaced in routing logs + + @abstractmethod + def match(self, ctx: RouteContext) -> bool: + ... + + @abstractmethod + def summary(self) -> str: + """Short, log-friendly description, e.g. ``image_prefix=remote/``.""" + ... + + +@dataclass(frozen=True) +class ImagePrefixMatcher(Matcher): + """Match when ``ctx.image`` starts with ``prefix``.""" + + yaml_key: str = "image_prefix" + prefix: str = "" + + def match(self, ctx: RouteContext) -> bool: + return bool(self.prefix) and ctx.image.startswith(self.prefix) + + def summary(self) -> str: + return f"image_prefix={self.prefix}" + + +@dataclass(frozen=True) +class ImagePatternMatcher(Matcher): + """Match when ``ctx.image`` matches regex ``pattern``. + + Pattern is compiled at construction; invalid patterns fail-fast at + config load time. + """ + + yaml_key: str = "image_pattern" + pattern: str = "" + + def match(self, ctx: RouteContext) -> bool: + if not self.pattern: + return False + return re.search(self.pattern, ctx.image) is not None + + def summary(self) -> str: + return f"image_pattern={self.pattern}" + + +def _build_image_prefix(value: Any) -> Matcher: + if not isinstance(value, str): + raise ValueError(f"image_prefix must be a string, got {type(value).__name__}") + return ImagePrefixMatcher(prefix=value) + + +def _build_image_pattern(value: Any) -> Matcher: + if not isinstance(value, str): + raise ValueError(f"image_pattern must be a string, got {type(value).__name__}") + # Compile-check the pattern early so misconfiguration fails at startup. + try: + re.compile(value) + except re.error as e: + raise ValueError(f"image_pattern is not a valid regex: {value!r} ({e})") from e + return ImagePatternMatcher(pattern=value) + + +# yaml-key → builder. Builders take the raw yaml value and return a Matcher. +# To extend: add a new (key, builder) entry; no other code changes needed. +MATCHER_REGISTRY: dict[str, callable] = { + "image_prefix": _build_image_prefix, + "image_pattern": _build_image_pattern, +} + + +def build_matcher(key: str, value: Any) -> Matcher: + """Build a Matcher from a yaml ``match`` entry. + + Raises ``ValueError`` for unknown keys — startup fail-fast prevents + typos from silently routing all traffic to default. + """ + builder = MATCHER_REGISTRY.get(key) + if builder is None: + registered = ", ".join(sorted(MATCHER_REGISTRY)) or "(none)" + raise ValueError(f"unknown match key {key!r}; registered keys: {registered}") + return builder(value) diff --git a/rock/sandbox/operator/routing/router.py b/rock/sandbox/operator/routing/router.py new file mode 100644 index 0000000000..87a4f5f661 --- /dev/null +++ b/rock/sandbox/operator/routing/router.py @@ -0,0 +1,106 @@ +"""Sequential rule-based router. + +Single semantics: walk rules in declaration order; first whose ``match`` +block fully matches wins; otherwise fall back to ``default``. + +Multiple match keys within one rule are AND-combined; OR is expressed by +splitting into multiple rules. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from rock.logger import init_logger +from rock.sandbox.operator.routing.context import RouteContext +from rock.sandbox.operator.routing.matcher import Matcher, build_matcher + +logger = init_logger(__name__) + + +@dataclass(frozen=True) +class RoutingRule: + """One yaml rule entry. Matchers are AND-combined.""" + + target: str + matchers: tuple[Matcher, ...] + + def match(self, ctx: RouteContext) -> bool: + if not self.matchers: + # Defensive: an empty match block would always fire and shadow + # every later rule. Validated to be non-empty at parse time. + return False + return all(m.match(ctx) for m in self.matchers) + + def summary(self) -> str: + return ",".join(m.summary() for m in self.matchers) + + +@dataclass +class Router: + default: str + rules: list[RoutingRule] = field(default_factory=list) + + def route(self, ctx: RouteContext) -> tuple[str, str]: + """Return (operator_name, routed_by) where routed_by is a log tag.""" + for idx, rule in enumerate(self.rules, start=1): + if rule.match(ctx): + return rule.target, f"rule#{idx}({rule.summary()})" + return self.default, "default" + + @classmethod + def from_config( + cls, + routing_cfg: dict[str, Any] | None, + fallback_default: str, + loaded_operators: set[str], + ) -> "Router": + """Build router from yaml ``runtime.operator_routing`` block. + + Args: + routing_cfg: Parsed ``operator_routing`` dict, or None when the + block is absent — every submit goes to default. + fallback_default: Used when ``routing_cfg.default`` is omitted. + Caller passes ``runtime.operator_type`` here to preserve + the legacy semantics. + loaded_operators: Names of operators actually loaded into the + registry. Used to validate ``default`` and rule targets. + """ + cfg = routing_cfg or {} + default = cfg.get("default") or fallback_default + if default not in loaded_operators: + raise ValueError( + f"routing default {default!r} is not a loaded operator; " + f"loaded={sorted(loaded_operators)}" + ) + + raw_rules = cfg.get("rules") or [] + if not isinstance(raw_rules, list): + raise ValueError("operator_routing.rules must be a list") + + rules: list[RoutingRule] = [] + for i, raw in enumerate(raw_rules, start=1): + if not isinstance(raw, dict): + raise ValueError(f"rule #{i}: must be a mapping, got {type(raw).__name__}") + target = raw.get("target") + if not target or not isinstance(target, str): + raise ValueError(f"rule #{i}: 'target' is required and must be a string") + if target not in loaded_operators: + raise ValueError( + f"rule #{i}: target {target!r} is not a loaded operator; " + f"loaded={sorted(loaded_operators)}" + ) + match_block = raw.get("match") or {} + if not isinstance(match_block, dict) or not match_block: + raise ValueError(f"rule #{i}: 'match' must be a non-empty mapping") + matchers = tuple(build_matcher(k, v) for k, v in match_block.items()) + rules.append(RoutingRule(target=target, matchers=matchers)) + + logger.info( + "Router built: default=%s rules=%d (operators=%s)", + default, + len(rules), + sorted(loaded_operators), + ) + return cls(default=default, rules=rules) diff --git a/rock/sandbox/sandbox_manager.py b/rock/sandbox/sandbox_manager.py index 1c1ff3070a..d0b0403397 100644 --- a/rock/sandbox/sandbox_manager.py +++ b/rock/sandbox/sandbox_manager.py @@ -33,6 +33,8 @@ from rock.sandbox import __version__ as gateway_version from rock.sandbox.base_manager import BaseManager from rock.sandbox.operator.abstract import AbstractOperator +from rock.sandbox.operator.registry import OperatorRegistry +from rock.sandbox.operator.routing import RouteContext from rock.sandbox.sandbox_actor import SandboxActor from rock.sandbox.sandbox_meta_store import SandboxMetaStore from rock.sandbox.sandbox_statemachine import SandboxStateMachine @@ -54,10 +56,10 @@ def __init__( self, rock_config: RockConfig, meta_store: SandboxMetaStore, + registry: OperatorRegistry, ray_namespace: str = env_vars.ROCK_RAY_NAMESPACE, ray_service: RayService | None = None, enable_runtime_auto_clear: bool = False, - operator: AbstractOperator | None = None, ): super().__init__( rock_config, @@ -66,10 +68,40 @@ def __init__( ) self._ray_service = ray_service self._ray_namespace = ray_namespace - self._operator = operator + if not registry.loaded_names: + raise ValueError("SandboxManager requires a non-empty OperatorRegistry") + self._registry = registry self._aes_encrypter = AESEncryption() self._proxy_service = SandboxProxyService(rock_config=rock_config, meta_store=meta_store) - logger.info("sandbox service init success") + logger.info( + "sandbox service init success: operators=%s default=%s", + sorted(self._registry.loaded_names), + self._registry.default_name, + ) + + async def _get_operator_for_sandbox(self, sandbox_id: str) -> AbstractOperator: + """Resolve the operator that owns ``sandbox_id`` via meta.operator_name. + + GET / stop / restart / delete must hit the same operator that + originally submitted the sandbox. Falls back to the registry default + when the meta record is missing or pre-dates the multi-operator + rollout (operator_name absent). + """ + info = await self._meta_store.get(sandbox_id, check_db=True) + name = (info or {}).get("operator_name") if info else None + if not name: + return self._registry.get(self._registry.default_name) + try: + return self._registry.get(name) + except KeyError: + logger.warning( + "meta operator_name=%s for sandbox=%s is not loaded; falling back to default=%s", + name, + sandbox_id, + self._registry.default_name, + ) + return self._registry.get(self._registry.default_name) + async def _get_current_statemachine(self, sandbox_id: str) -> SandboxStateMachine | None: """Fetch current state from meta store and return a restored SandboxStateMachine, or None if not found.""" @@ -135,7 +167,12 @@ async def start_async( docker_deployment_config.cpus = self.rock_config.runtime.standard_spec.cpus docker_deployment_config.memory = self.rock_config.runtime.standard_spec.memory with StageTimer("startup_timing", f"[{sandbox_id}] Operator submit", logger): - sandbox_info: SandboxInfo = await self._operator.submit(docker_deployment_config, user_info) + route_ctx = RouteContext.from_deployment(docker_deployment_config, user_info, cluster_info) + operator_name, operator = self._registry.resolve(route_ctx) + sandbox_info: SandboxInfo = await operator.submit(docker_deployment_config, user_info) + # Bind sandbox → operator once, so all downstream GET/stop/delete + # dispatch via meta lookup hit the same operator instance. + sandbox_info["operator_name"] = operator_name await self._build_sandbox_info_metadata(sandbox_info, user_info, cluster_info) timeout_info = SandboxTimeoutHelper.make_timeout_info(docker_deployment_config.auto_clear_time) with StageTimer("startup_timing", f"[{sandbox_id}] Meta store create", logger): @@ -161,10 +198,11 @@ async def restart_async(self, sandbox_id: str) -> SandboxStartResponse: if state != State.STOPPED: raise BadRequestRockError(f"Sandbox {sandbox_id} cannot be restarted: current state is '{state.value}'") + operator = await self._get_operator_for_sandbox(sandbox_id) await sm.send( "restart", sandbox_id=sandbox_id, - operator=self._operator, + operator=operator, meta_store=self._meta_store, ) @@ -193,10 +231,11 @@ async def start(self, config: DeploymentConfig) -> SandboxStartResponse: @monitor_sandbox_operation() async def stop(self, sandbox_id: str, reason: StopReason = StopReason.MANUAL): sm = await self._get_current_statemachine(sandbox_id) + operator = await self._get_operator_for_sandbox(sandbox_id) if sm is None: logger.info(f"stop dangling sandbox {sandbox_id}") try: - await self._operator.stop(sandbox_id, reason=reason) + await operator.stop(sandbox_id, reason=reason) except ValueError as e: logger.error(f"ray get actor, actor {sandbox_id} not exist", exc_info=e) elif sm.current_state.value == State.STOPPED: @@ -205,7 +244,7 @@ async def stop(self, sandbox_id: str, reason: StopReason = StopReason.MANUAL): await sm.send( "stop", sandbox_id=sandbox_id, - operator=self._operator, + operator=operator, meta_store=self._meta_store, reason=reason, ) @@ -218,7 +257,7 @@ async def stop(self, sandbox_id: str, reason: StopReason = StopReason.MANUAL): await sm.send( "delete", sandbox_id=sandbox_id, - operator=self._operator, + operator=operator, meta_store=self._meta_store, reason=DeleteReason.IMMEDIATE, ) @@ -237,10 +276,11 @@ async def delete(self, sandbox_id: str, reason: DeleteReason = DeleteReason.MANU raise BadRequestRockError( f"Sandbox {sandbox_id} cannot be deleted: current state is '{state.value}', must be stopped first" ) + operator = await self._get_operator_for_sandbox(sandbox_id) await sm.send( "delete", sandbox_id=sandbox_id, - operator=self._operator, + operator=operator, meta_store=self._meta_store, reason=reason, ) @@ -279,7 +319,8 @@ async def get_status(self, sandbox_id, include_all_states: bool = False) -> Sand # update status from operator is_alive = False - operator_sandbox_info: SandboxInfo | None = await self._operator.get_status(sandbox_id=sandbox_id) + operator = await self._get_operator_for_sandbox(sandbox_id) + operator_sandbox_info: SandboxInfo | None = await operator.get_status(sandbox_id=sandbox_id) if operator_sandbox_info is not None: is_alive = operator_sandbox_info.get("state") == State.RUNNING if sm.current_state.value == State.PENDING and is_alive: diff --git a/rock/sandbox/service/sandbox_proxy_service.py b/rock/sandbox/service/sandbox_proxy_service.py index e42564a7bd..ee1de32469 100644 --- a/rock/sandbox/service/sandbox_proxy_service.py +++ b/rock/sandbox/service/sandbox_proxy_service.py @@ -34,7 +34,7 @@ from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest from rock.admin.proto.request import SandboxWriteFileRequest as WriteFileRequest from rock.admin.proto.response import SandboxListResponse, SandboxListStatusResponse, SandboxStatusResponse -from rock.config import OssConfig, ProxyServiceConfig, RockConfig +from rock.config import OssConfig, ProxyServiceConfig, RemoteConfig, RockConfig from rock.deployments.constants import Port from rock.deployments.status import ServiceStatus from rock.common.port_validation import validate_port_forward_port @@ -61,6 +61,11 @@ def __init__(self, rock_config: RockConfig, meta_store: SandboxMetaStore): ) self.oss_config: OssConfig = rock_config.oss self.proxy_config: ProxyServiceConfig = rock_config.proxy_service + # Optional remote config; when set, sandbox_info entries carrying + # ``extended_params['remote.sandbox_domain']`` are routed via the + # remote data-plane proxy instead of the local host_ip/port + # (Addressing Layer). + self._remote_config: RemoteConfig | None = rock_config.remote logger.info(f"proxy config: {self.proxy_config}") # Initialize httpx client with configuration self._httpx_client = httpx.AsyncClient( @@ -69,6 +74,7 @@ def __init__(self, rock_config: RockConfig, meta_store: SandboxMetaStore): max_connections=self.proxy_config.max_connections, max_keepalive_connections=self.proxy_config.max_keepalive_connections, ), + verify=False, ) # Replace single self.sts_client with a dict keyed by account name, @@ -630,7 +636,13 @@ async def _forward_tcp_to_websocket(self, reader, websocket, direction: str, idl async def get_service_status(self, sandbox_id: str): sandbox_info = await self._meta_store.get(sandbox_id) - if not sandbox_info or sandbox_info.get("host_ip") is None: + # Remote-managed sandboxes have no local host_ip; they expose a + # ``remote.sandbox_domain`` in extended_params instead. Accept either + # as the addressing key. + ext = (sandbox_info or {}).get("extended_params") or {} + if not sandbox_info or ( + sandbox_info.get("host_ip") is None and not ext.get("remote.sandbox_domain") + ): raise Exception(f"sandbox {sandbox_id} not started") return [sandbox_info] @@ -644,10 +656,17 @@ async def _send_request( files: dict | None, method: str, ): - host_ip = sandbox_status_dict.get("host_ip") service_status = ServiceStatus.from_dict(sandbox_status_dict) - api_url = self._api_url(host_ip, service_status) headers = self._headers(sandbox_id) + # Addressing Layer: prefer remote data-plane URL when available; + # fall back to host_ip/port for local (ray/k8s) sandboxes. + remote = self._resolve_remote_data_plane(sandbox_status_dict) + if remote is not None: + api_url, extra_headers = remote + headers.update(extra_headers) + else: + host_ip = sandbox_status_dict.get("host_ip") + api_url = self._api_url(host_ip, service_status) logger.info(f"headers: {headers}") full_request_url = f"{api_url}/{path}" logger.info(f"full_request_url: {full_request_url}") @@ -668,6 +687,18 @@ async def _send_request( return {"exit_code": -1, "failure_reason": response.json()["rockletexception"]["message"]} if response.status_code == HTTP_504_GATEWAY_TIMEOUT: return {"exit_code": -1, "failure_reason": response.json()["detail"]} + # Remote data-plane goes through an HTTP gateway (e.g. tengine + # ingress) that may return non-JSON HTML on 4xx/5xx. Surface a + # clear error rather than a confusing JSONDecodeError. Ray/K8s + # path is unaffected (rocklet always returns JSON). + if remote is not None and response.status_code >= 400: + body_text = response.text[:200] + logger.error( + f"Upstream returned HTTP {response.status_code} for {full_request_url}: {body_text}" + ) + raise Exception( + f"Upstream error {response.status_code} from sandbox data-plane: {body_text}" + ) return response.json() except httpx.RequestError as e: # Handle network-level errors, such as DNS resolution failure, connection timeout, etc. @@ -682,6 +713,32 @@ def _api_url(self, host_ip: str, service_status: ServiceStatus) -> str: port = service_status.get_mapped_port(Port.PROXY) return f"http://{host_ip}:{port}" + def _resolve_remote_data_plane( + self, sandbox_status_dict: dict + ) -> tuple[str, dict[str, str]] | None: + """Return (base_url, extra_headers) for remote sandbox data-plane proxy.""" + if not self._remote_config: + return None + ext = sandbox_status_dict.get("extended_params") or {} + sandbox_domain = ext.get("remote.sandbox_domain") + if not sandbox_domain: + return None + remote_id = sandbox_status_dict.get("host_name") or ext.get("remote.sandbox_id") + if not remote_id: + logger.warning( + "remote sandbox missing remote id: keys=%s", list(sandbox_status_dict.keys()) + ) + return None + base_url = f"http://{sandbox_domain}" + headers: dict[str, str] = { + self._remote_config.header_sandbox_id: remote_id, + self._remote_config.header_sandbox_port: str(self._remote_config.rocklet_port), + } + traffic_token = ext.get("remote.traffic_access_token") + if traffic_token: + headers["X-Traffic-Token"] = traffic_token + return base_url, headers + def gen_oss_sts_token( self, account: str = "legacy" ) -> dict | None: # CHANGED: account param, default "legacy" preserves BC diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 9e4a5172b2..d67f61e234 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -20,6 +20,7 @@ from rock.sandbox.operator.k8s.operator import K8sOperator from rock.sandbox.operator.k8s.template_loader import K8sTemplateLoader from rock.sandbox.operator.ray import RayOperator +from rock.sandbox.operator.registry import OperatorRegistry from rock.sandbox.sandbox_manager import SandboxManager from rock.sandbox.sandbox_meta_store import SandboxMetaStore from rock.sandbox.service.sandbox_proxy_service import SandboxProxyService @@ -107,13 +108,16 @@ async def sandbox_manager( meta_store = SandboxMetaStore( redis_provider=redis_provider, sandbox_table=_memory_sandbox_table, rock_config=rock_config ) + operator_name = rock_config.runtime.operator_type or "ray" + registry = OperatorRegistry(default_name=operator_name) + registry.register(operator_name, ray_operator) sandbox_manager = SandboxManager( rock_config, meta_store=meta_store, + registry=registry, ray_namespace=rock_config.ray.namespace, ray_service=ray_service, enable_runtime_auto_clear=rock_config.runtime.enable_auto_clear, - operator=ray_operator, ) return sandbox_manager diff --git a/tests/unit/sandbox/operator/test_k8s_operator.py b/tests/unit/sandbox/operator/test_k8s_operator.py index d0dcb784fa..93c0151919 100644 --- a/tests/unit/sandbox/operator/test_k8s_operator.py +++ b/tests/unit/sandbox/operator/test_k8s_operator.py @@ -244,6 +244,7 @@ async def get_current_statemachine(sandbox_id: str): return await SandboxStateMachine.from_state_value(info.get("state"), sandbox_info=info) m._get_current_statemachine = AsyncMock(side_effect=get_current_statemachine) + m._get_operator_for_sandbox = AsyncMock(return_value=k8s_operator) m.get_status = SandboxManager.get_status.__get__(m, SandboxManager) m._refresh_timeout = AsyncMock() return m diff --git a/tests/unit/sandbox/test_get_status_include_all_states.py b/tests/unit/sandbox/test_get_status_include_all_states.py index 1a25340c2b..15cb2c03ba 100644 --- a/tests/unit/sandbox/test_get_status_include_all_states.py +++ b/tests/unit/sandbox/test_get_status_include_all_states.py @@ -56,6 +56,7 @@ async def sandbox_manager(mock_operator, mock_meta_store, rock_config): mock_sm = await SandboxStateMachine.from_state_value(State.PENDING, sandbox_info={}) manager._get_current_statemachine = AsyncMock(return_value=mock_sm) + manager._get_operator_for_sandbox = AsyncMock(return_value=mock_operator) return manager diff --git a/tests/unit/sandbox/test_sandbox_manager_delete.py b/tests/unit/sandbox/test_sandbox_manager_delete.py index c411196baa..2019a0607d 100644 --- a/tests/unit/sandbox/test_sandbox_manager_delete.py +++ b/tests/unit/sandbox/test_sandbox_manager_delete.py @@ -13,6 +13,7 @@ from rock.actions.sandbox.response import State from rock.common.constants import DeleteReason from rock.config import RockConfig, SandboxConfig +from rock.sandbox.operator.registry import OperatorRegistry from rock.sandbox.sandbox_manager import SandboxManager from rock.sdk.common.exceptions import BadRequestRockError @@ -27,6 +28,8 @@ def rock_config_min(): @pytest.fixture def manager(rock_config_min): operator = AsyncMock() + registry = OperatorRegistry(default_name="test") + registry.register("test", operator) meta_store = AsyncMock() meta_store.get = AsyncMock(return_value=None) # Patch BaseManager scheduler setup so tests don't spawn APScheduler. @@ -34,11 +37,16 @@ def manager(rock_config_min): m = SandboxManager( rock_config=rock_config_min, meta_store=meta_store, + registry=registry, ray_namespace="test", ray_service=MagicMock(), enable_runtime_auto_clear=False, - operator=operator, ) + # Force every dispatch through the same mock operator regardless of the + # ``operator_name`` recorded on the meta record, and expose it under + # ``manager._operator`` so assertions stay readable. + m._operator = operator + m._get_operator_for_sandbox = AsyncMock(return_value=operator) return m diff --git a/tests/unit/sandbox/test_sandbox_transitions.py b/tests/unit/sandbox/test_sandbox_transitions.py index ce2756d412..c318920c2a 100644 --- a/tests/unit/sandbox/test_sandbox_transitions.py +++ b/tests/unit/sandbox/test_sandbox_transitions.py @@ -59,6 +59,7 @@ async def get_current_statemachine(sandbox_id: str) -> SandboxStateMachine | Non return await SandboxStateMachine.from_state_value(info.get("state"), sandbox_info=info) m._get_current_statemachine = AsyncMock(side_effect=get_current_statemachine) + m._get_operator_for_sandbox = AsyncMock(return_value=mock_operator) m.stop = SandboxManager.stop.__get__(m, SandboxManager) m.get_status = SandboxManager.get_status.__get__(m, SandboxManager) @@ -318,6 +319,9 @@ def mgr_start(mgr, mock_meta_store, mock_operator, mock_docker_config): mgr.rock_config = MagicMock() mgr.rock_config.runtime.use_standard_spec_only = False + mgr._registry = MagicMock() + mgr._registry.resolve = MagicMock(return_value=("docker", mock_operator)) + mgr.start_async = SandboxManager.start_async.__wrapped__.__get__(mgr) mgr._build_sandbox_info_metadata = AsyncMock() mgr.start = SandboxManager.start.__wrapped__.__get__(mgr) diff --git a/uv.lock b/uv.lock index eb6e373dfc..d71682ef87 100644 --- a/uv.lock +++ b/uv.lock @@ -4280,7 +4280,7 @@ wheels = [ [[package]] name = "rl-rock" -version = "1.8.0" +version = "1.8.3" source = { editable = "." } dependencies = [ { name = "anyio" },