@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
  """Enables foreign key constraints on SQLite database connections.

  This is SQLite-specific because other databases (like PostgreSQL) enforce
  foreign keys by default and do not support SQLite's PRAGMA syntax.
  We perform an isinstance check against the standard sqlite3.Connection
  and SQLAlchemy's aiosqlite adapter wrapper to verify if this is an SQLite
  connection; any other connection type is left untouched.

  Args:
    dbapi_connection: The database connection to configure.
    connection_record: Metadata about the connection. Unused.
  """
  del connection_record
  connection_types = (sqlite3.Connection, AsyncAdapt_aiosqlite_connection)

  if not isinstance(dbapi_connection, connection_types):
    return

  cursor = dbapi_connection.cursor()
  try:
    cursor.execute("PRAGMA foreign_keys=ON")
  finally:
    # Release the cursor even when the PRAGMA statement raises, so a failed
    # connect hook does not leave a dangling cursor on the connection.
    cursor.close()
+ sqlalchemy.Index( + "uq_tier_path_active_backend", "asset_uuid", "storage_backend_id", - name="uq_tier_path_asset_backend", + unique=True, + sqlite_where=sqlalchemy.column("state").in_([ + TierPathState.PENDING.name, + TierPathState.IN_PROGRESS.name, + TierPathState.READY.name, + ]), + postgresql_where=sqlalchemy.column("state").in_([ + TierPathState.PENDING.name, + TierPathState.IN_PROGRESS.name, + TierPathState.READY.name, + ]), ), ) @@ -350,7 +378,8 @@ def __repr__(self): return ( f"TierPath(id={self.id}, asset_uuid='{self.asset_uuid}'," f" storage_backend_id={self.storage_backend_id}, path='{self.path}'," - f" ready_at={self.ready_at}, expires_at={self.expires_at})" + f" state={self.state.name}, ready_at={self.ready_at}," + f" expires_at={self.expires_at})" ) @@ -367,6 +396,13 @@ class AssetJob(Base): status: Current execution status of the job, an instance of JobStatus. target_tier_path_id: Foreign key to the targeted TierPath for operations such as COPY or DELETE_FROM_INSTANCE. + request_id: A unique identifier (UUID) for this job execution request. + transfer_status: JSON dictionary containing progress and GCP operation + details. + expiration_at: Timestamp when the worker's lease on this job expires. + last_updated_at: Timestamp of the last status update or heartbeat. + worker_host: Hostname of the worker processing this job. + worker_pid: Process ID of the worker processing this job. created_at: Timestamp when the job was created. completed_at: Timestamp when the job was completed. asset: Relationship to the associated Asset. 
@@ -396,9 +432,24 @@ class AssetJob(Base): # Target tier path for COPY and DELETE_FROM_INSTANCE requests target_tier_path_id = sqlalchemy.Column( sqlalchemy.Integer, - sqlalchemy.ForeignKey("tier_paths.id", ondelete="CASCADE"), + sqlalchemy.ForeignKey("tier_paths.id"), nullable=True, ) + request_id = sqlalchemy.Column( + sqlalchemy.String, + nullable=False, + unique=True, + default=lambda: str(uuid.uuid4()), + ) + transfer_status = sqlalchemy.Column(sqlalchemy.JSON, nullable=True) + expiration_at = sqlalchemy.Column( + sqlalchemy.DateTime(timezone=True), nullable=True + ) + last_updated_at = sqlalchemy.Column( + sqlalchemy.DateTime(timezone=True), nullable=True + ) + worker_host = sqlalchemy.Column(sqlalchemy.String, nullable=True) + worker_pid = sqlalchemy.Column(sqlalchemy.Integer, nullable=True) created_at = sqlalchemy.Column( sqlalchemy.DateTime(timezone=True), @@ -416,9 +467,11 @@ class AssetJob(Base): # target_tier_path is required in COPY and DELETE_FROM_INSTANCE requests. sqlalchemy.CheckConstraint( """ - (request_type IN ('REQUEST_TYPE_COPY', 'REQUEST_TYPE_DELETE_FROM_INSTANCE') AND target_tier_path_id IS NOT NULL) + (request_type IN ('REQUEST_TYPE_COPY', 'REQUEST_TYPE_DELETE_FROM_INSTANCE') + AND target_tier_path_id IS NOT NULL) OR - (request_type IN ('REQUEST_TYPE_DELETE_FROM_ALL_TIERS', 'REQUEST_TYPE_UNSPECIFIED') AND target_tier_path_id IS NULL) + (request_type IN ('REQUEST_TYPE_DELETE_FROM_ALL_TIERS', 'REQUEST_TYPE_UNSPECIFIED') + AND target_tier_path_id IS NULL) """, name="check_asset_job_valid_payload", ), @@ -430,5 +483,6 @@ def __repr__(self): f" request_type='{self.request_type.name}'," f" status='{self.status.name}'," f" target_tier_path_id={self.target_tier_path_id}," + f" request_id='{self.request_id}'," f" created_at={self.created_at}, completed_at={self.completed_at})" ) diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py index 
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
  """Turns on foreign key enforcement for every new DBAPI connection.

  The hook is registered on the Engine class, so it runs for each connection
  opened by any engine in this test module. It issues SQLite's PRAGMA without
  checking the connection type.
  NOTE(review): this presumes the tests only ever connect to SQLite - confirm
  before pointing the suite at another database.

  Args:
    dbapi_connection: The raw DBAPI connection that was just opened.
    connection_record: Metadata about the connection. Unused.
  """
  del connection_record
  cursor = dbapi_connection.cursor()
  try:
    cursor.execute("PRAGMA foreign_keys=ON")
  finally:
    # Always release the cursor, including when the PRAGMA statement fails.
    cursor.close()
  async def test_tier_path_conditional_uniqueness(self) -> None:
    """Checks that (asset, backend) uniqueness only applies to some states.

    Walks one asset/backend pair through a sequence of inserts: a second row
    is rejected while a PENDING row exists, and further rows are accepted once
    that row has been moved to FAILED, including a row created as DELETED.
    Each step depends on the rows committed by the previous steps.
    """
    async with self.session_maker() as session:
      # Parent rows that every TierPath below refers to.
      asset = db_schema.Asset(
          asset_uuid="uuid-cond-uniq",
          path="/experiment/cond-uniq",
          user="testuser",
      )
      backend = db_schema.StorageBackend(
          level=0,
          zone="us-central1-a",
          backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
          prefix="gs://gcs-bucket",
      )
      session.add_all([asset, backend])
      await session.commit()

      # 1. We can insert one PENDING tier path
      tp_pending = db_schema.TierPath(
          asset_uuid="uuid-cond-uniq",
          storage_backend=backend,
          path="/path-pending",
          state=db_schema.TierPathState.PENDING,
      )
      session.add(tp_pending)
      await session.commit()

      # 2. Trying to insert another IN_PROGRESS tier path should fail
      tp_in_progress = db_schema.TierPath(
          asset_uuid="uuid-cond-uniq",
          storage_backend=backend,
          path="/path-inprogress",
          state=db_schema.TierPathState.IN_PROGRESS,
      )
      session.add(tp_in_progress)
      with self.assertRaises(sqlalchemy.exc.IntegrityError):
        await session.commit()
      # The failed commit leaves the session unusable until it is rolled back.
      await session.rollback()

      # 3. Transition the first one to FAILED
      # Only the PENDING row from step 1 was committed, so first() returns it.
      result = await session.execute(
          select(db_schema.TierPath).filter_by(asset_uuid="uuid-cond-uniq")
      )
      tp = result.scalars().first()
      tp.state = db_schema.TierPathState.FAILED
      await session.commit()

      # 4. Now we can insert a new PENDING tier path for the same backend!
      tp_new_pending = db_schema.TierPath(
          asset_uuid="uuid-cond-uniq",
          storage_backend=backend,
          path="/path-new-pending",
          state=db_schema.TierPathState.PENDING,
      )
      session.add(tp_new_pending)
      await session.commit()

      # 5. And we can also insert a DELETED tier path for the same backend!
      tp_deleted = db_schema.TierPath(
          asset_uuid="uuid-cond-uniq",
          storage_backend=backend,
          path="/path-deleted",
          state=db_schema.TierPathState.DELETED,
      )
      session.add(tp_deleted)
      await session.commit()

      # Verify all 3 rows exist (FAILED, DELETED, PENDING)
      result = await session.execute(
          select(db_schema.TierPath).filter_by(asset_uuid="uuid-cond-uniq")
      )
      paths = result.scalars().all()
      self.assertLen(paths, 3)
class OperationStatus(enum.Enum):
  """The status of a job or operation."""

  IN_PROGRESS = "IN_PROGRESS"
  SUCCESS = "SUCCESS"
  FAILED = "FAILED"


@dataclasses.dataclass
class Result:
  """Outcome of polling a transfer operation.

  Attributes:
    status: Whether the operation is still running, succeeded, or failed.
    detail_info: Free-form details supplied by the client that produced the
      result (for example progress counters or an "error" entry).
  """

  status: OperationStatus
  detail_info: dict[str, Any]


@dataclasses.dataclass
class TransferContext:
  """Caller-supplied context handed to `poll_operation`.

  Attributes:
    job_request_id: Request identifier of the job being polled.
    source_path: Path the data is copied from.
    destination_path: Path the data is copied to.
    transfer_status: Status dictionary associated with the transfer.
      NOTE(review): schema is not defined in this module - confirm with the
      caller that populates it.
  """

  job_request_id: str
  source_path: str
  destination_path: str
  transfer_status: dict[str, Any]


class _HttpxResponse(transport.Response):
  """A google-auth compatible transport response using HTTPX."""

  def __init__(self, resp: httpx.Response):
    self._resp = resp

  @property
  def status(self) -> int:
    return self._resp.status_code

  @property
  def headers(self) -> dict[str, str]:
    return dict(self._resp.headers)

  @property
  def data(self) -> bytes:
    return self._resp.content


class HttpxRequest(transport.Request):
  """A google-auth compatible transport request using HTTPX (sync)."""

  def __init__(self, client: httpx.Client):
    """Wraps `client`; the caller keeps ownership and must close it."""
    self._client = client

  def __call__(
      self,
      url: str,
      method: str = "GET",
      body: bytes | None = None,
      headers: dict[str, str] | None = None,
      timeout: float | None = None,
      **kwargs,
  ) -> transport.Response:
    """Performs one HTTP request on behalf of google-auth.

    Args:
      url: Absolute URL to request.
      method: HTTP method name.
      body: Raw request payload, if any.
      headers: Request headers, if any.
      timeout: Seconds to wait for the server. When omitted (None) the timeout
        configured on the wrapped httpx client applies.
      **kwargs: Extra keyword arguments forwarded to `httpx.Client.request`.

    Returns:
      The response wrapped in the google-auth transport.Response interface.

    Raises:
      google.auth.exceptions.TransportError: If the request times out or fails
        at the transport level.
    """
    # httpx interprets an explicit timeout=None as "never time out". A caller
    # that simply did not pass a timeout must not be able to hang forever, so
    # fall back to the client's own timeout configuration in that case.
    effective_timeout = (
        httpx.USE_CLIENT_DEFAULT if timeout is None else timeout
    )
    try:
      response = self._client.request(
          method=method,
          url=url,
          headers=headers,
          content=body,
          timeout=effective_timeout,
          **kwargs,
      )
    except httpx.TimeoutException as e:
      # TimeoutException is a RequestError subclass, so it must come first.
      raise auth_exceptions.TransportError(f"Timeout: {e}") from e
    except httpx.RequestError as e:
      raise auth_exceptions.TransportError(f"Request error: {e}") from e
    return _HttpxResponse(response)
(e.g. + + Lustre, GCS). + """ + + def __init__( + self, + project: str | None = None, + location: str | None = None, + instance: str | None = None, + service_account: str | None = None, + ): + self.project = project + self.location = location + self.instance = instance + self.service_account = service_account + self._credentials = None + self._async_client = None + + @property + def async_client(self) -> httpx.AsyncClient: + if self._async_client is None: + self._async_client = httpx.AsyncClient() + return self._async_client + + async def close(self): + if self._async_client is not None: + await self._async_client.aclose() + self._async_client = None + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + async def _get_token_and_project(self) -> tuple[str, str]: + """Gets authentication credentials and projects.""" + if not self._credentials: + base_credentials, detected_project = await asyncio.to_thread( + google.auth.default, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + if not self.project: + self.project = detected_project + + if self.service_account: + self._credentials = impersonated_credentials.Credentials( + source_credentials=base_credentials, + target_principal=self.service_account, + target_scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + else: + self._credentials = base_credentials + + if not self._credentials.valid: + with httpx.Client() as client: + await asyncio.to_thread(self._credentials.refresh, HttpxRequest(client)) + + if not self.project: + raise ValueError("GCP Project ID must be specified or auto-detected.") + + return self._credentials.token, self.project + + @abc.abstractmethod + async def trigger_copy( + self, + request_id: str, + source_path: str, + destination_path: str, + ) -> str: + """Triggers copy and returns operation name.""" + pass + + @abc.abstractmethod + async def poll_operation( + self, + operation_name: str, + 
context: TransferContext | None = None, + ) -> Result: + """Polls operation status and returns a Result object.""" + pass + + +def _parse_gcs_path(gcs_path: str) -> tuple[str, str]: + """Parses a GCS path like gs://bucket/prefix/file into (bucket, prefix).""" + path_no_scheme = gcs_path.replace("gs://", "") + parts = path_no_scheme.split("/", 1) + bucket = parts[0] + prefix = parts[1] if len(parts) > 1 else "" + return bucket, prefix + + +class GcsToGcsClient(GCPStorageClient): + """Client implementation for GCS-to-GCS operations using Storage Transfer Service.""" + + def __init__( + self, + project: str | None = None, + service_account: str | None = None, + ): + super().__init__(project=project, service_account=service_account) + + async def trigger_copy( + self, + request_id: str, + source_path: str, + destination_path: str, + ) -> str: + """Triggers GCS-to-GCS transfer using GCP Storage Transfer Service (STS).""" + token, project = await self._get_token_and_project() + url = "https://storagetransfer.googleapis.com/v1/transferJobs" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + + src_bucket, src_prefix = _parse_gcs_path(source_path) + dest_bucket, dest_prefix = _parse_gcs_path(destination_path) + + now = datetime.datetime.now(datetime.timezone.utc) + payload = { + "projectId": project, + "transferSpec": { + "gcsDataSource": { + "bucketName": src_bucket, + "path": src_prefix, + }, + "gcsDataSink": { + "bucketName": dest_bucket, + "path": dest_prefix, + }, + "transferOptions": { + "overwriteObjectsAlreadyExistingInSink": True, + }, + }, + "schedule": { + "scheduleStartDate": { + "year": now.year, + "month": now.month, + "day": now.day, + }, + }, + "status": "DISABLED", + } + + response = await self.async_client.post(url, json=payload, headers=headers) + + if response.status_code != 200: + raise RuntimeError( + f"Failed to create Storage Transfer Job: {response.status_code} -" + f" {response.text}" + ) + + job_name = 
response.json()["name"] + + # Run the job immediately. This returns the operation name directly. + run_url = f"https://storagetransfer.googleapis.com/v1/{job_name}:run" + run_payload = {"projectId": project} + run_response = await self.async_client.post( + run_url, json=run_payload, headers=headers + ) + + if run_response.status_code != 200: + raise RuntimeError( + f"Failed to run Storage Transfer Job {job_name}:" + f" {run_response.status_code} - {run_response.text}" + ) + + operation_name = run_response.json()["name"] + return operation_name + + async def poll_operation( + self, + operation_name: str, + context: TransferContext | None = None, + ) -> Result: + """Polls Storage Transfer Service operation status.""" + token, _ = await self._get_token_and_project() + url = f"https://storagetransfer.googleapis.com/v1/{operation_name}" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + response = await self.async_client.get(url, headers=headers) + if response.status_code != 200: + raise RuntimeError( + f"Failed to poll Storage Transfer operation: {response.status_code} -" + f" {response.text}" + ) + + data = response.json() + done = data.get("done", False) + + if not done: + metadata = data.get("metadata", {}) + counters = metadata.get("counters", {}) + bytes_transferred = int(counters.get("bytesTransferredToSink", 0)) + bytes_found = int(counters.get("bytesFoundToTransfer", 0)) + return Result( + status=OperationStatus.IN_PROGRESS, + detail_info={ + "bytes_copied": bytes_transferred, + "total_bytes": bytes_found, + }, + ) + + if "error" in data: + return Result( + status=OperationStatus.FAILED, + detail_info={"error": data["error"]}, + ) + + metadata = data.get("metadata", {}) + op_status = metadata.get("status") + + if op_status == "SUCCESS": + return Result( + status=OperationStatus.SUCCESS, + detail_info={}, + ) + else: + error_msg = f"STS Operation ended with status: {op_status}" + return Result( + 
class GcpLustreBaseClient(GCPStorageClient):
  """Common REST plumbing for clients of the GCP Managed Lustre API."""

  def __init__(
      self,
      project: str | None = None,
      location: str | None = None,
      instance: str | None = None,
      service_account: str | None = None,
  ):
    """Resolves the Lustre location and instance, then initializes the base.

    Args:
      project: GCP project ID, if known.
      location: Lustre location; falls back to the CTS_LUSTRE_LOCATION
        environment variable when empty.
      instance: Lustre instance; falls back to the CTS_LUSTRE_INSTANCE
        environment variable when empty.
      service_account: Service account passed through to the base client.

    Raises:
      ValueError: If location or instance cannot be resolved.
    """
    resolved_location = (
        location if location else os.environ.get("CTS_LUSTRE_LOCATION")
    )
    resolved_instance = (
        instance if instance else os.environ.get("CTS_LUSTRE_INSTANCE")
    )

    if not (resolved_location and resolved_instance):
      raise ValueError("Lustre location and instance must be specified.")

    super().__init__(
        project=project,
        location=resolved_location,
        instance=resolved_instance,
        service_account=service_account,
    )

  async def _start_transfer(
      self, action: str, verb: str, payload: dict[str, Any]
  ) -> str:
    """POSTs `payload` to the instance's `action` endpoint.

    Args:
      action: Custom method appended to the instance URL after the colon.
      verb: Word used in the error message ("import" or "export").
      payload: JSON request body.

    Returns:
      The "name" field of the JSON response.

    Raises:
      RuntimeError: If the API does not answer with HTTP 200.
    """
    token, project = await self._get_token_and_project()
    endpoint = (
        f"https://lustre.googleapis.com/v1/projects/{project}"
        f"/locations/{self.location}/instances/{self.instance}:{action}"
    )
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    reply = await self.async_client.post(
        endpoint, json=payload, headers=request_headers
    )
    if reply.status_code != 200:
      raise RuntimeError(
          f"Failed to trigger {verb}: {reply.status_code} - {reply.text}"
      )
    return reply.json()["name"]

  async def poll_operation(
      self,
      operation_name: str,
      context: TransferContext | None = None,
  ) -> Result:
    """Fetches the operation and maps it onto a Result.

    Args:
      operation_name: Operation resource name to query.
      context: Ignored by the Lustre clients.

    Returns:
      IN_PROGRESS with the operation's "metadata" while it is not done;
      FAILED with its "error" when done with an error; otherwise SUCCESS with
      its "response".

    Raises:
      RuntimeError: If the API does not answer with HTTP 200.
    """
    del context  # Unused for Lustre
    token, _ = await self._get_token_and_project()
    endpoint = f"https://lustre.googleapis.com/v1/{operation_name}"
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    reply = await self.async_client.get(endpoint, headers=request_headers)
    if reply.status_code != 200:
      raise RuntimeError(
          f"Failed to poll operation: {reply.status_code} - {reply.text}"
      )

    operation = reply.json()
    if not operation.get("done", False):
      return Result(
          status=OperationStatus.IN_PROGRESS,
          detail_info=operation.get("metadata", {}),
      )
    if "error" in operation:
      return Result(
          status=OperationStatus.FAILED,
          detail_info={"error": operation["error"]},
      )
    return Result(
        status=OperationStatus.SUCCESS,
        detail_info=operation.get("response", {}),
    )


class GcsToLustreClient(GcpLustreBaseClient):
  """Client implementation to trigger GCS-to-Lustre imports."""

  async def trigger_copy(
      self,
      request_id: str,
      source_path: str,
      destination_path: str,
  ) -> str:
    """Starts an import of `source_path` (GCS) into `destination_path` (Lustre).

    Args:
      request_id: Identifier sent to the API as "requestId".
      source_path: GCS URI to import from.
      destination_path: Lustre path to import into.

    Returns:
      The name of the operation created by the API.

    Raises:
      RuntimeError: If the API does not answer with HTTP 200.
    """
    import_request = {
        "gcsPath": {"uri": source_path},
        "lustrePath": {"path": destination_path},
        "requestId": request_id,
    }
    return await self._start_transfer("importData", "import", import_request)


class LustreToGcsClient(GcpLustreBaseClient):
  """Client implementation to trigger Lustre-to-GCS exports."""

  async def trigger_copy(
      self,
      request_id: str,
      source_path: str,
      destination_path: str,
  ) -> str:
    """Starts an export of `source_path` (Lustre) to `destination_path` (GCS).

    Args:
      request_id: Identifier sent to the API as "requestId".
      source_path: Lustre path to export from.
      destination_path: GCS URI to export to.

    Returns:
      The name of the operation created by the API.

    Raises:
      RuntimeError: If the API does not answer with HTTP 200.
    """
    export_request = {
        "lustrePath": {"path": source_path},
        "gcsPath": {"uri": destination_path},
        "requestId": request_id,
    }
    return await self._start_transfer("exportData", "export", export_request)
class GCPStorageClientTest(unittest.IsolatedAsyncioTestCase):
  """Tests the GCP storage clients with auth and HTTP traffic mocked out."""

  # Environment read by the Lustre clients when location and instance are not
  # passed to the constructor.
  _LUSTRE_ENV = {
      "CTS_LUSTRE_LOCATION": "us-central1-a",
      "CTS_LUSTRE_INSTANCE": "lustre-1",
  }

  def setUp(self):
    super().setUp()
    self.mock_creds = mock.MagicMock()
    self.mock_creds.valid = True
    self.mock_creds.token = "dummy_auth_token"
    self.mock_default_auth = mock.patch(
        "google.auth.default", return_value=(self.mock_creds, "dummy-project")
    )
    self.mock_default_auth.start()
    # Register the stop right after each start: cleanups run even when a
    # later line of setUp raises, whereas tearDown would be skipped and the
    # patch would leak into other tests.
    self.addCleanup(self.mock_default_auth.stop)

    # Mock httpx.AsyncClient
    self.mock_client = mock.AsyncMock(spec=httpx.AsyncClient)
    self.mock_client_patcher = mock.patch(
        "httpx.AsyncClient", return_value=self.mock_client
    )
    self.mock_client_patcher.start()
    self.addCleanup(self.mock_client_patcher.stop)

  @staticmethod
  def _response(status_code=200, *, payload=None, text=None):
    """Builds a mocked httpx.Response.

    Args:
      status_code: Value for `status_code`.
      payload: If given, the value returned by `json()`.
      text: If given, the value of `text`.

    Returns:
      A MagicMock constrained to the httpx.Response interface.
    """
    response = mock.MagicMock(spec=httpx.Response)
    response.status_code = status_code
    if payload is not None:
      response.json.return_value = payload
    if text is not None:
      response.text = text
    return response

  async def test_gcs_to_gcs_trigger_copy_success(self):
    client = gcp_storage_client.GcsToGcsClient(project="test-project")

    # The first POST creates the transfer job, the second one runs it.
    self.mock_client.post.side_effect = [
        self._response(payload={"name": "transferJobs/job-123"}),
        self._response(payload={"name": "transferOperations/op-456"}),
    ]

    op_name = await client.trigger_copy(
        request_id="req-1",
        source_path="gs://src-bucket/path/to/src",
        destination_path="gs://dest-bucket/path/to/dest",
    )

    self.assertEqual(op_name, "transferOperations/op-456")
    self.assertEqual(self.mock_client.post.call_count, 2)

  async def test_gcs_to_gcs_trigger_copy_sts_fail(self):
    client = gcp_storage_client.GcsToGcsClient(project="test-project")

    self.mock_client.post.return_value = self._response(
        400, text="Invalid arguments"
    )

    with self.assertRaises(RuntimeError) as ctx:
      await client.trigger_copy(
          request_id="req-1",
          source_path="gs://src-bucket/path",
          destination_path="gs://dest-bucket/path",
      )
    self.assertIn("Failed to create Storage Transfer Job", str(ctx.exception))

  async def test_gcs_to_gcs_poll_operation_in_progress(self):
    client = gcp_storage_client.GcsToGcsClient(project="test-project")

    # Counters are strings in the mocked payload; the client returns ints.
    self.mock_client.get.return_value = self._response(
        payload={
            "done": False,
            "metadata": {
                "counters": {
                    "bytesTransferredToSink": "500",
                    "bytesFoundToTransfer": "1000",
                }
            },
        }
    )

    result = await client.poll_operation("transferOperations/op-456")
    self.assertEqual(
        result.status, gcp_storage_client.OperationStatus.IN_PROGRESS
    )
    self.assertEqual(result.detail_info["bytes_copied"], 500)
    self.assertEqual(result.detail_info["total_bytes"], 1000)

  async def test_gcs_to_gcs_poll_operation_success(self):
    client = gcp_storage_client.GcsToGcsClient(project="test-project")

    self.mock_client.get.return_value = self._response(
        payload={"done": True, "metadata": {"status": "SUCCESS"}}
    )

    result = await client.poll_operation("transferOperations/op-456")
    self.assertEqual(result.status, gcp_storage_client.OperationStatus.SUCCESS)

  async def test_gcs_to_gcs_poll_operation_failed(self):
    client = gcp_storage_client.GcsToGcsClient(project="test-project")

    self.mock_client.get.return_value = self._response(
        payload={"done": True, "metadata": {"status": "FAILED"}}
    )

    result = await client.poll_operation("transferOperations/op-456")
    self.assertEqual(result.status, gcp_storage_client.OperationStatus.FAILED)
    self.assertIn("error", result.detail_info)

  @mock.patch.dict(os.environ, _LUSTRE_ENV)
  async def test_gcs_to_lustre_trigger_copy(self):
    client = gcp_storage_client.GcsToLustreClient(project="test-project")

    self.mock_client.post.return_value = self._response(
        payload={"name": "operations/import-123"}
    )

    op_name = await client.trigger_copy(
        request_id="req-1",
        source_path="gs://src-bucket/path",
        destination_path="/lustre/path",
    )
    self.assertEqual(op_name, "operations/import-123")
    self.mock_client.post.assert_called_once()

  @mock.patch.dict(os.environ, _LUSTRE_ENV)
  async def test_lustre_to_gcs_trigger_copy(self):
    client = gcp_storage_client.LustreToGcsClient(project="test-project")

    self.mock_client.post.return_value = self._response(
        payload={"name": "operations/export-123"}
    )

    op_name = await client.trigger_copy(
        request_id="req-1",
        source_path="/lustre/path",
        destination_path="gs://dest-bucket/path",
    )
    self.assertEqual(op_name, "operations/export-123")
    self.mock_client.post.assert_called_once()

  @mock.patch.dict(os.environ, _LUSTRE_ENV)
  async def test_lustre_poll_operation_done_success(self):
    client = gcp_storage_client.GcsToLustreClient(project="test-project")

    self.mock_client.get.return_value = self._response(
        payload={"done": True, "response": {"some_metadata": "val"}}
    )

    result = await client.poll_operation("operations/import-123")
    self.assertEqual(result.status, gcp_storage_client.OperationStatus.SUCCESS)
    self.assertEqual(result.detail_info, {"some_metadata": "val"})

  @mock.patch.dict(os.environ, _LUSTRE_ENV)
  async def test_lustre_poll_operation_done_fail(self):
    client = gcp_storage_client.GcsToLustreClient(project="test-project")

    self.mock_client.get.return_value = self._response(
        payload={"done": True, "error": {"message": "import failed"}}
    )

    result = await client.poll_operation("operations/import-123")
    self.assertEqual(result.status, gcp_storage_client.OperationStatus.FAILED)
    self.assertEqual(result.detail_info["error"], {"message": "import failed"})

  @mock.patch.dict(os.environ, _LUSTRE_ENV)
  async def test_lustre_poll_operation_in_progress(self):
    client = gcp_storage_client.GcsToLustreClient(project="test-project")

    self.mock_client.get.return_value = self._response(
        payload={"done": False, "metadata": {"percent_complete": 42}}
    )

    result = await client.poll_operation("operations/import-123")
    self.assertEqual(
        result.status, gcp_storage_client.OperationStatus.IN_PROGRESS
    )
    self.assertEqual(result.detail_info, {"percent_complete": 42})

  @mock.patch("google.auth.impersonated_credentials.Credentials")
  async def test_service_account_token_impersonation(
      self, mock_impersonated_creds_class
  ):
    mock_imp_creds = mock.MagicMock()
    mock_imp_creds.valid = True
    mock_imp_creds.token = "impersonated_token"
    mock_impersonated_creds_class.return_value = mock_imp_creds

    client = gcp_storage_client.GcsToGcsClient(
        project="test-project",
        service_account="sa@test.iam.gserviceaccount.com",
    )

    token, project = await client._get_token_and_project()
    self.assertEqual(token, "impersonated_token")
    self.assertEqual(project, "test-project")

    # Verify that impersonated credentials were constructed correctly
    mock_impersonated_creds_class.assert_called_once_with(
        source_credentials=self.mock_creds,
        target_principal="sa@test.iam.gserviceaccount.com",
        target_scopes=["https://www.googleapis.com/auth/cloud-platform"],
    )
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Checkpoint Tiering Service (CTS) worker to handle AssetJobs.

Consumes queued AssetJobs and manages the asynchronous data movement (e.g.
Lustre <-> GCS import/export).
"""

import asyncio
import datetime
import os
import socket
from typing import Any
from absl import logging
from orbax.checkpoint.experimental.tiering_service import assets
from orbax.checkpoint.experimental.tiering_service import db_schema
from orbax.checkpoint.experimental.tiering_service.gcp_storage_client import GCPStorageClient
from orbax.checkpoint.experimental.tiering_service.gcp_storage_client import GcsToGcsClient
from orbax.checkpoint.experimental.tiering_service.gcp_storage_client import GcsToLustreClient
from orbax.checkpoint.experimental.tiering_service.gcp_storage_client import LustreToGcsClient
from orbax.checkpoint.experimental.tiering_service.gcp_storage_client import OperationStatus
from orbax.checkpoint.experimental.tiering_service.gcp_storage_client import TransferContext
from orbax.checkpoint.experimental.tiering_service.proto import tiering_service_pb2
import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
import sqlalchemy.orm


class TieringServiceWorker:
  """Background worker that processes AssetJobs.

  The worker runs two asyncio loops: one that claims eligible jobs and triggers
  their transfer, and one that polls the transfers this worker owns (matched by
  hostname and pid) until they succeed or fail.
  """

  def __init__(
      self,
      session_maker: Any,
      config: tiering_service_pb2.ServerConfig,
      gcp_client: GCPStorageClient | None = None,
      *,
      lease_duration_seconds: int = 60,
      poll_interval_seconds: int = 5,
  ):
    """Initializes the background worker.

    Args:
      session_maker: A callable that returns a database session usable as an
        async context manager.
      config: The server configuration.
      gcp_client: Optional client used for every transfer. When None, a client
        is resolved per job from the source/target backend types.
      lease_duration_seconds: Duration of the lease acquired for jobs.
      poll_interval_seconds: Sleep between iterations of both worker loops.
    """
    self._session_maker = session_maker
    self._config = config
    self._gcp_client = gcp_client
    self._lease_duration = datetime.timedelta(seconds=lease_duration_seconds)
    self._poll_interval = poll_interval_seconds
    # Hostname and pid identify the jobs owned by this worker process.
    self._hostname = socket.gethostname()
    self._pid = os.getpid()
    self._tasks: list[asyncio.Task] = []
    self._shutdown_event = asyncio.Event()

  @staticmethod
  def _find_source_tier_path(
      job: db_schema.AssetJob,
  ) -> db_schema.TierPath | None:
    """Returns the first ready TierPath of the job's asset other than the target.

    Args:
      job: The job whose asset tier paths are searched. `job.asset.tier_paths`
        and `job.target_tier_path` must already be loaded.

    Returns:
      The first tier path whose id differs from the target's and whose
      `ready_at` is set, or None when the job has no target or no such path.
    """
    target_tp = job.target_tier_path
    if target_tp is None:
      return None
    # TODO(dnlng): just find a source that is closest to the target.
    for tp in job.asset.tier_paths:
      if tp.id != target_tp.id and tp.ready_at is not None:
        return tp
    return None

  def _get_client_for_backends(
      self,
      source_backend: db_schema.StorageBackend,
      target_backend: db_schema.StorageBackend,
  ) -> GCPStorageClient:
    """Returns the appropriate GCP client for the given transfer backends.

    Args:
      source_backend: Backend the data is read from.
      target_backend: Backend the data is written to.

    Returns:
      The injected client if one was passed to the constructor, otherwise a
      new client matching the (source, target) backend types.

    Raises:
      ValueError: If the backend type pair is not GCS->GCS, Lustre->GCS or
        GCS->Lustre.
    """
    if self._gcp_client is not None:
      return self._gcp_client

    # Empty proto strings are normalized to None.
    project = self._config.gcp_project or None
    service_account = self._config.service_account or None

    gcs = db_schema.BackendType.BACKEND_TYPE_GCS
    lustre = db_schema.BackendType.BACKEND_TYPE_LUSTRE
    source_type = source_backend.backend_type
    target_type = target_backend.backend_type

    if source_type == gcs and target_type == gcs:
      return GcsToGcsClient(project=project, service_account=service_account)

    if source_type == lustre and target_type == gcs:
      # The Lustre instance name is derived from the Lustre backend's zone.
      location = source_backend.zone
      return LustreToGcsClient(
          instance=f"lustre-{location}",
          location=location,
          project=project,
          service_account=service_account,
      )

    if source_type == gcs and target_type == lustre:
      location = target_backend.zone
      return GcsToLustreClient(
          instance=f"lustre-{location}",
          location=location,
          project=project,
          service_account=service_account,
      )

    raise ValueError(
        f"Unsupported backend pair: {source_backend.backend_type} and"
        f" {target_backend.backend_type}"
    )

  def _get_client_for_job(self, job: db_schema.AssetJob) -> GCPStorageClient:
    """Returns the client for the given job.

    Args:
      job: The job to resolve a client for.

    Returns:
      The injected client if present, otherwise a client chosen from the
      backends of the job's ready source tier path and its target tier path.

    Raises:
      ValueError: If the source or target tier path cannot be resolved, or the
        backend pair is unsupported.
    """
    if self._gcp_client is not None:
      return self._gcp_client

    target_tp = job.target_tier_path
    # Use the same source selection as the copy/poll paths so that the client
    # is resolved against the backend the transfer actually reads from.
    source_tp = self._find_source_tier_path(job)

    if not source_tp or not target_tp:
      raise ValueError(
          f"Could not resolve source and target backends for job {job.id}"
      )

    return self._get_client_for_backends(
        source_tp.storage_backend, target_tp.storage_backend
    )

  async def start(self):
    """Starts the background worker loops.

    Calling start() while the loops are still running is a no-op, so the worker
    never runs duplicate acquisition/polling loops.
    """
    if any(not task.done() for task in self._tasks):
      logging.warning(
          "TieringServiceWorker already running on %s:%d; ignoring start().",
          self._hostname,
          self._pid,
      )
      return
    # Drop references to loops that have already finished.
    self._tasks.clear()

    logging.info(
        "Starting TieringServiceWorker on %s:%d", self._hostname, self._pid
    )
    self._shutdown_event.clear()
    self._tasks.append(asyncio.create_task(self._run_acquisition_loop()))
    self._tasks.append(asyncio.create_task(self._run_polling_loop()))

  async def stop(self):
    """Stops the background worker loops gracefully."""
    logging.info("Stopping TieringServiceWorker...")
    self._shutdown_event.set()
    if self._tasks:
      # Cancel the loops (they may be sleeping) and wait for them to finish;
      # return_exceptions=True absorbs the resulting CancelledError.
      for task in self._tasks:
        task.cancel()
      await asyncio.gather(*self._tasks, return_exceptions=True)
      self._tasks.clear()
    logging.info("TieringServiceWorker stopped.")

  async def _run_acquisition_loop(self):
    """Periodically acquires and triggers queued jobs."""
    while not self._shutdown_event.is_set():
      try:
        async with self._session_maker() as session:
          # Claiming the job and triggering its transfer share one transaction.
          async with session.begin():
            job = await self._acquire_next_job(session)
            if job:
              logging.info("Acquired job %d", job.id)
              await self._process_job(session, job)
      except Exception:  # pylint: disable=broad-except
        # Keep the loop alive; the next iteration retries.
        logging.exception("Error in job acquisition loop.")
      await asyncio.sleep(self._poll_interval)

  async def _run_polling_loop(self):
    """Periodically polls status of active jobs owned by this worker."""
    while not self._shutdown_event.is_set():
      try:
        async with self._session_maker() as session:
          async with session.begin():
            await self._poll_active_jobs(session)
      except Exception:  # pylint: disable=broad-except
        # Keep the loop alive; the next iteration retries.
        logging.exception("Error in job polling loop.")
      await asyncio.sleep(self._poll_interval)

  async def _acquire_next_job(
      self, session: AsyncSession
  ) -> db_schema.AssetJob | None:
    """Queries the database for the next eligible job and claims it.

    A job is eligible when it is QUEUED, or PROCESSING with an expired lease,
    its asset has no other live PROCESSING job, and its target backend (if any)
    is below `max_active_jobs_per_backend` live PROCESSING jobs.

    Args:
      session: The database session; the caller owns the transaction.

    Returns:
      The claimed job (status set to PROCESSING, lease and owner filled in), or
      None when no job is eligible.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    max_active = self._config.max_active_jobs_per_backend

    # 1. Identify assets currently executing active transfers (PROCESSING and
    # not expired)
    active_assets_subquery = (
        select(db_schema.AssetJob.asset_uuid)
        .where(
            db_schema.AssetJob.status
            == db_schema.JobStatus.JOB_STATUS_PROCESSING,
            db_schema.AssetJob.expiration_at >= now,
        )
        .scalar_subquery()
    )

    # 2. Identify storage backends that have reached their maximum active
    # job limit (N)
    busy_backends_subquery = (
        select(db_schema.TierPath.storage_backend_id)
        .join(
            db_schema.AssetJob,
            db_schema.AssetJob.target_tier_path_id == db_schema.TierPath.id,
        )
        .where(
            db_schema.AssetJob.status
            == db_schema.JobStatus.JOB_STATUS_PROCESSING,
            db_schema.AssetJob.expiration_at >= now,
        )
        .group_by(db_schema.TierPath.storage_backend_id)
        .having(sqlalchemy.func.count(db_schema.AssetJob.id) >= max_active)
        .scalar_subquery()
    )

    # 3. Fetch the oldest eligible job (QUEUED or expired PROCESSING)
    stmt = (
        select(db_schema.AssetJob)
        .options(
            sqlalchemy.orm.selectinload(
                db_schema.AssetJob.target_tier_path
            ).selectinload(db_schema.TierPath.storage_backend),
            sqlalchemy.orm.selectinload(db_schema.AssetJob.asset)
            .selectinload(db_schema.Asset.tier_paths)
            .selectinload(db_schema.TierPath.storage_backend),
        )
        .join(
            db_schema.TierPath,
            db_schema.AssetJob.target_tier_path_id == db_schema.TierPath.id,
            isouter=True,
        )
        .where(
            sqlalchemy.or_(
                db_schema.AssetJob.status
                == db_schema.JobStatus.JOB_STATUS_QUEUED,
                sqlalchemy.and_(
                    db_schema.AssetJob.status
                    == db_schema.JobStatus.JOB_STATUS_PROCESSING,
                    db_schema.AssetJob.expiration_at < now,
                ),
            ),
            ~db_schema.AssetJob.asset_uuid.in_(active_assets_subquery),
            sqlalchemy.or_(
                # Jobs without a target tier path are not backend-limited.
                db_schema.AssetJob.target_tier_path_id.is_(None),
                ~db_schema.TierPath.storage_backend_id.in_(
                    busy_backends_subquery
                ),
            ),
        )
        .order_by(db_schema.AssetJob.created_at.asc())
        .limit(1)
        # Request a row lock and skip rows another worker has locked.
        .with_for_update(skip_locked=True)
    )

    result = await session.execute(stmt)
    job = result.scalars().first()

    if job:
      # Claim the job: mark it PROCESSING with a fresh lease owned by us.
      job.status = db_schema.JobStatus.JOB_STATUS_PROCESSING
      job.expiration_at = now + self._lease_duration
      job.worker_host = self._hostname
      job.worker_pid = self._pid
      job.last_updated_at = now
      session.add(job)
      if (
          job.request_type == db_schema.RequestType.REQUEST_TYPE_COPY
          and job.target_tier_path
      ):
        job.target_tier_path.state = db_schema.TierPathState.IN_PROGRESS
        session.add(job.target_tier_path)

    return job

  async def _process_job(self, session: AsyncSession, job: db_schema.AssetJob):
    """Triggers the GCP transfer for the acquired job.

    Only REQUEST_TYPE_COPY is handled; any other request type is marked FAILED.

    Args:
      session: The database session; the caller owns the transaction.
      job: The job just claimed by `_acquire_next_job`.
    """
    if job.request_type == db_schema.RequestType.REQUEST_TYPE_COPY:
      await self._process_copy_job(session, job)
    else:
      # For now, we only handle COPY in the worker. Other request types
      # (e.g. DELETE) are not implemented.
      logging.warning("Unsupported job request type: %s", job.request_type)
      await self._fail_job(
          session,
          job,
          f"Unsupported request type: {job.request_type}",
          {"status": "FAILED"},
      )

  async def _process_copy_job(
      self, session: AsyncSession, job: db_schema.AssetJob
  ):
    """Processes a COPY job (GCS -> Lustre or Lustre -> GCS).

    Copies from the asset's first ready non-target tier path to the job's
    target tier path. On success the operation name returned by the client is
    stored in `job.transfer_status["request_id"]` for later polling; on any
    error the job is marked FAILED.

    Args:
      session: The database session; the caller owns the transaction.
      job: The COPY job to trigger.
    """
    target_tp = job.target_tier_path
    if not target_tp:
      await self._fail_job(
          session,
          job,
          "Target TierPath not found",
          {"status": "FAILED"},
      )
      return

    source_tp = self._find_source_tier_path(job)
    if not source_tp:
      await self._fail_job(
          session,
          job,
          "Source TierPath not found or not ready",
          {"status": "FAILED"},
      )
      return

    # The transfer direction (import vs. export) follows from the backend
    # types of the source and target and is decided by the client selection.
    try:
      async with self._get_client_for_backends(
          source_tp.storage_backend, target_tp.storage_backend
      ) as client:
        operation_name = await client.trigger_copy(
            job.request_id, source_tp.path, target_tp.path
        )

      job.transfer_status = {
          "request_id": operation_name,
          "status": OperationStatus.IN_PROGRESS.value,
          "bytes_copied": 0,
          "total_bytes": 0,
      }
      session.add(job)
      logging.info(
          "Triggered transfer for job %d, operation_name: %s",
          job.id,
          operation_name,
      )

    except Exception as e:  # pylint: disable=broad-except
      logging.exception("Failed to trigger transfer for job %d", job.id)
      await self._fail_job(
          session,
          job,
          f"Failed to trigger transfer: {e}",
          {"status": "FAILED"},
      )

  async def _poll_active_jobs(self, session: AsyncSession):
    """Polls status of active jobs owned by this worker.

    For every PROCESSING job whose worker_host/worker_pid match this process,
    extends the lease, polls the transfer operation and then completes, fails
    or updates the job according to the reported status.

    Args:
      session: The database session; the caller owns the transaction.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    stmt = (
        select(db_schema.AssetJob)
        .options(
            sqlalchemy.orm.selectinload(
                db_schema.AssetJob.target_tier_path
            ).selectinload(db_schema.TierPath.storage_backend),
            sqlalchemy.orm.selectinload(db_schema.AssetJob.asset)
            .selectinload(db_schema.Asset.tier_paths)
            .selectinload(db_schema.TierPath.storage_backend),
        )
        .where(
            db_schema.AssetJob.status
            == db_schema.JobStatus.JOB_STATUS_PROCESSING,
            db_schema.AssetJob.worker_host == self._hostname,
            db_schema.AssetJob.worker_pid == self._pid,
        )
    )
    result = await session.execute(stmt)
    active_jobs = result.scalars().all()

    for job in active_jobs:
      logging.info("Polling job %d", job.id)
      # Heartbeat first so the lease stays valid even if polling fails.
      await self._extend_lease(session, job, now)

      status_dict = job.transfer_status
      if not status_dict or "request_id" not in status_dict:
        logging.warning("Job %d in PROCESSING but has no request_id", job.id)
        continue

      req_id = status_dict["request_id"]
      try:
        async with self._get_client_for_job(job) as client:
          target_tp = job.target_tier_path
          source_tp = self._find_source_tier_path(job)
          if not source_tp:
            raise ValueError(f"Source TierPath not found for job {job.id}")

          context = TransferContext(
              job_request_id=job.request_id,
              source_path=source_tp.path,
              destination_path=target_tp.path,
              transfer_status=status_dict,
          )
          gcp_result = await client.poll_operation(req_id, context=context)
          logging.info(
              "Job %d GCP status: %s, detail_info: %s",
              job.id,
              gcp_result.status,
              gcp_result.detail_info,
          )

        # Merge the polled status and details over the stored transfer_status.
        new_status = {
            **status_dict,
            "status": gcp_result.status.value,
            **gcp_result.detail_info,
        }
        job.transfer_status = new_status
        job.last_updated_at = now

        if gcp_result.status == OperationStatus.SUCCESS:
          await self._complete_job(session, job, now)
        elif gcp_result.status == OperationStatus.FAILED:
          error_msg = gcp_result.detail_info.get("error", "Unknown GCP error")
          await self._fail_job(session, job, error_msg, new_status, now)
        else:
          session.add(job)

      except Exception:  # pylint: disable=broad-except
        logging.exception("Error polling job %d", job.id)
        # Do not fail the job on a poll error; it is polled again on the next
        # iteration. Record that an attempt was made.
        job.last_updated_at = now
        session.add(job)

  async def _extend_lease(
      self,
      session: AsyncSession,
      job: db_schema.AssetJob,
      now: datetime.datetime,
  ):
    """Extends the lease of the job (heartbeat).

    Args:
      session: The database session.
      job: The job whose `expiration_at` is pushed forward.
      now: Reference time; the new expiry is `now` plus the lease duration.
    """
    job.expiration_at = now + self._lease_duration
    session.add(job)
    logging.info("Extended lease for job %d to %s", job.id, job.expiration_at)

  async def _complete_job(
      self,
      session: AsyncSession,
      job: db_schema.AssetJob,
      now: datetime.datetime,
  ):
    """Marks the job as completed and updates target TierPath.

    Args:
      session: The database session.
      job: The job to complete; its ownership and lease fields are cleared.
      now: Completion timestamp, also used as the target's `ready_at`.
    """
    job.status = db_schema.JobStatus.JOB_STATUS_COMPLETED
    job.completed_at = now
    job.worker_host = None
    job.worker_pid = None
    job.expiration_at = None
    session.add(job)

    # Mark target TierPath as ready
    target_tp = job.target_tier_path
    if target_tp:
      target_tp.state = db_schema.TierPathState.READY
      target_tp.ready_at = now
      # Only Lustre targets get an expiry; other backend types are left
      # without one. The TTL reuses client_keep_alive_interval_seconds.
      if (
          target_tp.storage_backend.backend_type
          == db_schema.BackendType.BACKEND_TYPE_LUSTRE
      ):
        ttl = datetime.timedelta(
            seconds=self._config.client_keep_alive_interval_seconds
        )
        target_tp.expires_at = assets.calculate_expires_at(ttl)
      session.add(target_tp)

    logging.info(
        "Completed job %d, target TierPath %s marked ready",
        job.id,
        target_tp.path if target_tp else "None",
    )

  async def _fail_job(
      self,
      session: AsyncSession,
      job: db_schema.AssetJob,
      error_msg: str,
      transfer_status: dict[str, Any],
      now: datetime.datetime | None = None,
  ):
    """Marks the job as failed and flags the target TierPath as FAILED.

    Args:
      session: The database session.
      job: The job to fail; its ownership and lease fields are cleared.
      error_msg: Error stored under the "error" key of the transfer status.
      transfer_status: Base transfer status to persist. It is copied, not
        modified.
      now: Failure timestamp; defaults to the current UTC time.
    """
    if now is None:
      now = datetime.datetime.now(datetime.timezone.utc)
    job.status = db_schema.JobStatus.JOB_STATUS_FAILED
    job.completed_at = now
    job.worker_host = None
    job.worker_pid = None
    job.expiration_at = None
    # Build a new dict so the caller's argument is left untouched.
    job.transfer_status = {**transfer_status, "error": error_msg}
    session.add(job)

    # The target TierPath row is kept; only its state is set to FAILED.
    target_tp = job.target_tier_path
    if target_tp:
      target_tp.state = db_schema.TierPathState.FAILED
      session.add(target_tp)

    logging.error("Failed job %d: %s", job.id, error_msg)


async def run_tiering_service_worker_loop(
    session_maker: Any,
    config: tiering_service_pb2.ServerConfig,
    gcp_client: GCPStorageClient | None = None,
    *,
    lease_duration_seconds: int = 60,
    poll_interval_seconds: int = 5,
) -> TieringServiceWorker:
  """Creates a TieringServiceWorker, starts its loops and returns it.

  Args:
    session_maker: A callable that returns a database session.
    config: The server configuration.
    gcp_client: Optional client used for every transfer.
    lease_duration_seconds: Duration of the lease acquired for jobs.
    poll_interval_seconds: Sleep between iterations of the worker loops.

  Returns:
    The started worker; the caller is responsible for calling `stop()`.
  """
  worker = TieringServiceWorker(
      session_maker,
      config,
      gcp_client,
      lease_duration_seconds=lease_duration_seconds,
      poll_interval_seconds=poll_interval_seconds,
  )
  await worker.start()
  return worker
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import datetime +import unittest +from unittest import mock + +from absl import logging +from absl.testing import absltest + +from orbax.checkpoint.experimental.tiering_service import db_lib +from orbax.checkpoint.experimental.tiering_service import db_schema +from orbax.checkpoint.experimental.tiering_service import gcp_storage_client +from orbax.checkpoint.experimental.tiering_service import job_worker +from orbax.checkpoint.experimental.tiering_service import server_config +from orbax.checkpoint.experimental.tiering_service.proto import tiering_service_pb2 +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm import sessionmaker + + +class DummyGcpParallelstoreClient(gcp_storage_client.GCPStorageClient): + """Dummy implementation of GCPStorageClient for testing.""" + + def __init__(self): + """Initializes the dummy client with empty operations list.""" + super().__init__() + self.operations = {} + + async def trigger_copy( + self, + request_id: str, + source_path: str, + destination_path: str, + ) -> str: + """Triggers copy in progress.""" + self.operations[request_id] = { + "status": gcp_storage_client.OperationStatus.IN_PROGRESS, + "progress": 0, + "type": "copy", + } + logging.info( + "Dummy triggered copy %s -> %s, request_id: %s", + source_path, + destination_path, + request_id, + ) + return request_id + + async def poll_operation( + self, + 
request_id: str, + context: gcp_storage_client.TransferContext | None = None, + ) -> gcp_storage_client.Result: + """Polls the status of the specified GCP operation.""" + op = self.operations.get(request_id) + if not op: + return gcp_storage_client.Result( + status=gcp_storage_client.OperationStatus.FAILED, + detail_info={"error": "Operation not found"}, + ) + + if op["status"] == gcp_storage_client.OperationStatus.IN_PROGRESS: + op["progress"] += 50 # Progress by 50% each poll + if op["progress"] >= 100: + op["status"] = gcp_storage_client.OperationStatus.SUCCESS + + return gcp_storage_client.Result( + status=op["status"], + detail_info={ + "bytes_copied": op["progress"] * 1000, + "total_bytes": 100000, + }, + ) + + +class TieringServiceWorkerTest( + absltest.TestCase, unittest.IsolatedAsyncioTestCase +): + + async def asyncSetUp(self): + await super().asyncSetUp() + storage_backends_config = [ + { + "level": 0, + "backend_type": "BACKEND_TYPE_LUSTRE", + "prefix": "/mnt/lustre-a", + "zone": "us-central1-a", + }, + { + "level": 0, + "backend_type": "BACKEND_TYPE_LUSTRE", + "prefix": "/mnt/lustre-b", + "zone": "us-central1-b", + }, + { + "level": 1, + "backend_type": "BACKEND_TYPE_GCS", + "prefix": "gs://my-bucket", + "region": "us-central1", + }, + ] + self.config = server_config.parse_config({ + "storage_backends": storage_backends_config, + "max_active_jobs_per_backend": 1, + }) + # Use temp file SQLite for testing + self.tmp_file = self.create_tempfile() + self.config.db_connection_str = ( + f"sqlite+aiosqlite:///{self.tmp_file.full_path}" + ) + + await db_lib.async_initialize_db(self.config) + self.engine = db_lib.get_async_engine(self.config) + self.session_maker = sessionmaker( + self.engine, expire_on_commit=False, class_=AsyncSession + ) + self.gcp_client = DummyGcpParallelstoreClient() + # Short poll interval for fast tests + self.worker = job_worker.TieringServiceWorker( + self.session_maker, + self.config, + gcp_client=self.gcp_client, + 
lease_duration_seconds=2, # Short lease for testing expiration + poll_interval_seconds=1, + ) + + async def asyncTearDown(self): + await self.worker.stop() + await self.engine.dispose() + await super().asyncTearDown() + + async def _create_asset_and_job( + self, + session, + *, + asset_uuid, + path, + target_backend_id, + source_backend_id, + job_status=db_schema.JobStatus.JOB_STATUS_QUEUED, + ): + """Helper to create an asset, its tier paths, and a prefetch job. + + Args: + session: The database session. + asset_uuid: Unique identifier of the asset. + path: Relative path of the asset. + target_backend_id: The ID of the target storage backend. + source_backend_id: The ID of the source storage backend. + job_status: The initial status of the prefetch job. + + Returns: + A tuple of (AssetJob, TierPath) created. + """ + del self + asset = db_schema.Asset( + asset_uuid=asset_uuid, + path=path, + user="test-user", + state=db_schema.AssetState.ASSET_STATE_STORED, + ) + session.add(asset) + + # Source TierPath (ready) + source_tp = db_schema.TierPath( + asset_uuid=asset_uuid, + storage_backend_id=source_backend_id, + path=f"gs://my-bucket/{path}", + ready_at=datetime.datetime(2020, 1, 1, tzinfo=datetime.timezone.utc), + ) + session.add(source_tp) + + # Target TierPath (pending) + target_tp = db_schema.TierPath( + asset_uuid=asset_uuid, + storage_backend_id=target_backend_id, + path=f"/mnt/lustre/{path}", + ) + session.add(target_tp) + await session.flush() + + job = db_schema.AssetJob( + asset_uuid=asset_uuid, + request_type=db_schema.RequestType.REQUEST_TYPE_COPY, + status=job_status, + target_tier_path_id=target_tp.id, + ) + session.add(job) + await session.commit() + return job, target_tp + + async def test_job_acquisition_success(self): + async with self.session_maker() as session: + # Get backend IDs + result = await session.execute(select(db_schema.StorageBackend)) + backends = result.scalars().all() + lustre_a = next(b for b in backends if b.zone == 
"us-central1-a") + gcs = next(b for b in backends if b.region == "us-central1") + + await self._create_asset_and_job( + session, + asset_uuid="asset-1", + path="path/1", + target_backend_id=lustre_a.id, + source_backend_id=gcs.id, + ) + + # Start worker to process the job + await self.worker.start() + + # Wait for job to complete with timeout + for _ in range(10): + await asyncio.sleep(1) + async with self.session_maker() as session: + result = await session.execute(select(db_schema.AssetJob)) + job = result.scalars().first() + if job.status == db_schema.JobStatus.JOB_STATUS_COMPLETED: + break + + await self.worker.stop() + + async with self.session_maker() as session: + # Verify job status transitions to COMPLETED (since dummy client + # progresses fast) + result = await session.execute(select(db_schema.AssetJob)) + job = result.scalars().first() + self.assertEqual(job.status, db_schema.JobStatus.JOB_STATUS_COMPLETED) + self.assertIsNotNone(job.completed_at) + self.assertIsNone(job.worker_host) + self.assertIsNone(job.worker_pid) + + # Verify target TierPath is ready + result_tp = await session.execute( + select(db_schema.TierPath).where( + db_schema.TierPath.asset_uuid == "asset-1" + ) + ) + tps = result_tp.scalars().all() + target_tp = next(tp for tp in tps if "lustre" in tp.path) + self.assertEqual( + target_tp.state, db_schema.TierPathState.READY + ) + self.assertIsNotNone(target_tp.ready_at) + self.assertIsNotNone(target_tp.expires_at) + + async def test_job_failure_clean_up_target_tier_path(self): + async with self.session_maker() as session: + result = await session.execute(select(db_schema.StorageBackend)) + backends = result.scalars().all() + lustre_a = next(b for b in backends if b.zone == "us-central1-a") + gcs = next(b for b in backends if b.region == "us-central1") + + await self._create_asset_and_job( + session, + asset_uuid="asset-1", + path="path/1", + target_backend_id=lustre_a.id, + source_backend_id=gcs.id, + ) + + # Mock poll_operation to 
return a failure + with mock.patch.object( + self.gcp_client, + "poll_operation", + autospec=True, + return_value=gcp_storage_client.Result( + status=gcp_storage_client.OperationStatus.FAILED, + detail_info={"error": "Mocked GCP error"}, + ), + ): + await self.worker.start() + # Wait for job to fail + for _ in range(10): + await asyncio.sleep(1) + async with self.session_maker() as session: + result = await session.execute(select(db_schema.AssetJob)) + job = result.scalars().first() + if job.status == db_schema.JobStatus.JOB_STATUS_FAILED: + break + await self.worker.stop() + + async with self.session_maker() as session: + # Verify job status is FAILED, completed_at is set, and + # target_tier_path_id is preserved + result = await session.execute(select(db_schema.AssetJob)) + job = result.scalars().first() + self.assertEqual(job.status, db_schema.JobStatus.JOB_STATUS_FAILED) + self.assertIsNotNone(job.completed_at) + self.assertIsNotNone(job.target_tier_path_id) + self.assertIsNone(job.worker_host) + self.assertIsNone(job.worker_pid) + + # Verify target TierPath is preserved but marked FAILED + result_tp = await session.execute( + select(db_schema.TierPath).where( + db_schema.TierPath.asset_uuid == "asset-1" + ) + ) + tps = result_tp.scalars().all() + # Both source (GCS) and target (Lustre) TierPaths should exist + self.assertLen(tps, 2) + dest_tp = next(tp for tp in tps if "my-bucket" not in tp.path) + self.assertEqual( + dest_tp.state, db_schema.TierPathState.FAILED + ) + + async def test_concurrency_limit_respected(self): + async with self.session_maker() as session: + result = await session.execute(select(db_schema.StorageBackend)) + backends = result.scalars().all() + lustre_a = next(b for b in backends if b.zone == "us-central1-a") + gcs = next(b for b in backends if b.region == "us-central1") + + # Create 2 jobs targeting the SAME backend (Lustre A) + await self._create_asset_and_job( + session, + asset_uuid="asset-1", + path="path/1", + 
target_backend_id=lustre_a.id, + source_backend_id=gcs.id, + ) + await self._create_asset_and_job( + session, + asset_uuid="asset-2", + path="path/2", + target_backend_id=lustre_a.id, + source_backend_id=gcs.id, + ) + + # Dummy GCP client to keep operations RUNNING so we can check concurrency + with mock.patch.object( + self.gcp_client, + "poll_operation", + autospec=True, + return_value=gcp_storage_client.Result( + status=gcp_storage_client.OperationStatus.IN_PROGRESS, + detail_info={"bytes_copied": 500}, + ), + ): + await self.worker.start() + await asyncio.sleep(2) + await self.worker.stop() + + async with self.session_maker() as session: + result = await session.execute( + select(db_schema.AssetJob).order_by(db_schema.AssetJob.id) + ) + jobs_list = result.scalars().all() + + # One should be PROCESSING, the other still QUEUED + self.assertEqual( + jobs_list[0].status, db_schema.JobStatus.JOB_STATUS_PROCESSING + ) + self.assertEqual( + jobs_list[1].status, db_schema.JobStatus.JOB_STATUS_QUEUED + ) + + async def test_different_backends_concurrency(self): + async with self.session_maker() as session: + result = await session.execute(select(db_schema.StorageBackend)) + backends = result.scalars().all() + lustre_a = next(b for b in backends if b.zone == "us-central1-a") + lustre_b = next(b for b in backends if b.zone == "us-central1-b") + gcs = next(b for b in backends if b.region == "us-central1") + + # Create 2 jobs targeting DIFFERENT backends (Lustre A and Lustre B) + await self._create_asset_and_job( + session, + asset_uuid="asset-1", + path="path/1", + target_backend_id=lustre_a.id, + source_backend_id=gcs.id, + ) + await self._create_asset_and_job( + session, + asset_uuid="asset-2", + path="path/2", + target_backend_id=lustre_b.id, + source_backend_id=gcs.id, + ) + + with mock.patch.object( + self.gcp_client, + "poll_operation", + autospec=True, + return_value=gcp_storage_client.Result( + status=gcp_storage_client.OperationStatus.IN_PROGRESS, + 
detail_info={"bytes_copied": 500}, + ), + ): + await self.worker.start() + await asyncio.sleep(2) + await self.worker.stop() + + async with self.session_maker() as session: + result = await session.execute(select(db_schema.AssetJob)) + jobs_list = result.scalars().all() + + # Both should be PROCESSING because they target different backends + self.assertEqual( + jobs_list[0].status, db_schema.JobStatus.JOB_STATUS_PROCESSING + ) + self.assertEqual( + jobs_list[1].status, db_schema.JobStatus.JOB_STATUS_PROCESSING + ) + + async def test_crash_recovery_on_lease_expiration(self): + async with self.session_maker() as session: + result = await session.execute(select(db_schema.StorageBackend)) + backends = result.scalars().all() + lustre_a = next(b for b in backends if b.zone == "us-central1-a") + gcs = next(b for b in backends if b.region == "us-central1") + + # Create a job that is already in PROCESSING state but has an + # expired lease + job, _ = await self._create_asset_and_job( + session, + asset_uuid="asset-1", + path="path/1", + target_backend_id=lustre_a.id, + source_backend_id=gcs.id, + job_status=db_schema.JobStatus.JOB_STATUS_PROCESSING, + ) + # Set expired lease + job.expiration_at = datetime.datetime( + 2020, 1, 1, tzinfo=datetime.timezone.utc + ) + job.worker_host = "dead-host" + job.worker_pid = 9999 + + # Populate dummy operation in client + op_id = "operations/import-dummy-id" + self.gcp_client.operations[op_id] = { + "status": gcp_storage_client.OperationStatus.IN_PROGRESS, + "progress": 0, + "type": "import", + } + job.transfer_status = { + "request_id": op_id, + "status": gcp_storage_client.OperationStatus.IN_PROGRESS.value, + } + await session.commit() + + # Start worker + await self.worker.start() + + # Wait for job to complete with timeout + for _ in range(10): + await asyncio.sleep(1) + async with self.session_maker() as session: + result = await session.execute(select(db_schema.AssetJob)) + recovered_job = result.scalars().first() + if 
recovered_job.status == db_schema.JobStatus.JOB_STATUS_COMPLETED: + break + + await self.worker.stop() + + async with self.session_maker() as session: + result = await session.execute(select(db_schema.AssetJob)) + recovered_job = result.scalars().first() + + # The worker should have reclaimed it and eventually completed it + # (via dummy client) + self.assertEqual( + recovered_job.status, db_schema.JobStatus.JOB_STATUS_COMPLETED + ) + self.assertIsNone(recovered_job.worker_host) + + +class DynamicClientResolutionTest(unittest.IsolatedAsyncioTestCase): + + def setUp(self): + super().setUp() + self.gcs_backend = db_schema.StorageBackend( + level=1, + region="us-central1", + backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://my-bucket", + ) + self.lustre_backend = db_schema.StorageBackend( + level=0, + zone="us-central1-a", + backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/mnt/lustre", + ) + + def test_resolves_gcs_to_lustre_client(self): + worker = job_worker.TieringServiceWorker( + session_maker=None, + config=tiering_service_pb2.ServerConfig(), + gcp_client=None, + ) + client = worker._get_client_for_backends( + self.gcs_backend, self.lustre_backend + ) + self.assertIsInstance(client, job_worker.GcsToLustreClient) + self.assertEqual(client.location, "us-central1-a") + self.assertEqual(client.instance, "lustre-us-central1-a") + + def test_resolves_lustre_to_gcs_client(self): + worker = job_worker.TieringServiceWorker( + session_maker=None, + config=tiering_service_pb2.ServerConfig(), + gcp_client=None, + ) + client = worker._get_client_for_backends( + self.lustre_backend, self.gcs_backend + ) + self.assertIsInstance(client, job_worker.LustreToGcsClient) + self.assertEqual(client.location, "us-central1-a") + self.assertEqual(client.instance, "lustre-us-central1-a") + + def test_resolves_gcs_to_gcs_client(self): + worker = job_worker.TieringServiceWorker( + session_maker=None, + config=tiering_service_pb2.ServerConfig(), + 
gcp_client=None, + ) + client = worker._get_client_for_backends(self.gcs_backend, self.gcs_backend) + self.assertIsInstance(client, job_worker.GcsToGcsClient) + + +if __name__ == "__main__": + absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/proto/tiering_service.proto b/checkpoint/orbax/checkpoint/experimental/tiering_service/proto/tiering_service.proto index eab98a32b..8ee7737e2 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/proto/tiering_service.proto +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/proto/tiering_service.proto @@ -150,6 +150,9 @@ message ServerConfig { int64 client_keep_alive_interval_seconds = 1 [default = 1800]; string db_connection_str = 2 [default = "sqlite+aiosqlite:///:memory:"]; repeated StorageBackend storage_backends = 3; + int64 max_active_jobs_per_backend = 4 [default = 1]; + string gcp_project = 5; + string service_account = 6; } service TieringService { diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/server.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/server.py index e39b19802..c63235f68 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/server.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/server.py @@ -30,6 +30,7 @@ from orbax.checkpoint.experimental.tiering_service import auth from orbax.checkpoint.experimental.tiering_service import db_lib from orbax.checkpoint.experimental.tiering_service import db_schema +from orbax.checkpoint.experimental.tiering_service import job_worker from orbax.checkpoint.experimental.tiering_service import server_config from orbax.checkpoint.experimental.tiering_service import storage_backend from orbax.checkpoint.experimental.tiering_service.proto import tiering_service_pb2 @@ -68,6 +69,11 @@ def __init__(self, config: tiering_service_pb2.ServerConfig): ) self._level0_backends: Sequence[db_schema.StorageBackend] | None = None + @property + def 
session_maker(self) -> sessionmaker: + """The session maker for the database.""" + return self._session_maker + async def initialize(self) -> None: """Initializes the servicer, loading static data.""" async with self._session_maker() as session: @@ -541,17 +547,21 @@ async def setup_storage_backends( class CtsServer: """Checkpoint Tiering Service (CTS) Server CLI.""" - async def serve(self, yaml_path: str) -> None: + async def serve( + self, yaml_path: str, start_tiering_service_worker: bool = False + ) -> None: """Starts the gRPC server. Args: yaml_path: Path to the YAML configuration file. + start_tiering_service_worker: Whether to start the tiering service worker. """ config = server_config.load_config(yaml_path) await setup_storage_backends(config) server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=10)) servicer = TieringServiceServicer(config) + worker = None try: await servicer.initialize() tiering_service_pb2_grpc.add_TieringServiceServicer_to_server( @@ -563,11 +573,17 @@ async def serve(self, yaml_path: str) -> None: server.add_secure_port("[::]:50051", server_creds) await server.start() - # TODO: b/503445463 - Start background garbage collection task to handle - # expired assets. 
+ # Start background worker + if start_tiering_service_worker: + worker = await job_worker.run_tiering_service_worker_loop( + servicer.session_maker, config + ) await server.wait_for_termination() finally: + await server.stop(grace=0) + if worker: + await worker.stop() await servicer.close() diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/server_config.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/server_config.py index c8755064a..224018f94 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/server_config.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/server_config.py @@ -124,6 +124,32 @@ def _parse_storage_backends( _parse_storage_backend(b_data, backend) +def _parse_max_active_jobs_per_backend( + data: Mapping[str, Any], config: tiering_service_pb2.ServerConfig +) -> None: + """Parses max active jobs per backend into ServerConfig.""" + if "max_active_jobs_per_backend" in data: + config.max_active_jobs_per_backend = int( + data["max_active_jobs_per_backend"] + ) + + +def _parse_gcp_project( + data: Mapping[str, Any], config: tiering_service_pb2.ServerConfig +) -> None: + """Parses gcp_project into ServerConfig.""" + if "gcp_project" in data and data["gcp_project"] is not None: + config.gcp_project = str(data["gcp_project"]) + + +def _parse_service_account( + data: Mapping[str, Any], config: tiering_service_pb2.ServerConfig +) -> None: + """Parses service_account into ServerConfig.""" + if "service_account" in data and data["service_account"] is not None: + config.service_account = str(data["service_account"]) + + def parse_config(data: Mapping[str, Any]) -> tiering_service_pb2.ServerConfig: """Parses a dictionary into a ServerConfig proto. 
@@ -138,6 +164,9 @@ def parse_config(data: Mapping[str, Any]) -> tiering_service_pb2.ServerConfig: _parse_client_keep_alive(data, config) _parse_db_connection(data, config) _parse_storage_backends(data, config) + _parse_max_active_jobs_per_backend(data, config) + _parse_gcp_project(data, config) + _parse_service_account(data, config) return config diff --git a/checkpoint/pyproject.toml b/checkpoint/pyproject.toml index f4490abe0..49c3e116d 100644 --- a/checkpoint/pyproject.toml +++ b/checkpoint/pyproject.toml @@ -81,6 +81,7 @@ tiering_service = [ 'fire', 'greenlet', 'grpcio-tools>=1.80.0', + 'httpx', 'pysqlite3', 'pytimeparse', 'sqlalchemy>=1.4.0',