Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/tiering_service/db_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@
"""Database initialization utilities for Tiering Service."""

import contextlib
import sqlite3

from orbax.checkpoint.experimental.tiering_service import db_schema
from orbax.checkpoint.experimental.tiering_service.proto import tiering_service_pb2
from sqlalchemy import event
from sqlalchemy.dialects.sqlite.aiosqlite import AsyncAdapt_aiosqlite_connection
from sqlalchemy.engine import Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession
Expand All @@ -25,6 +30,29 @@
from sqlalchemy.orm import sessionmaker


@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
  """Enables foreign key constraints on SQLite database connections.

  Registered as a listener for the "connect" event on `Engine`. The PRAGMA is
  SQLite-specific: other databases (like PostgreSQL) enforce foreign keys by
  default and do not support SQLite's PRAGMA syntax. A connection is treated
  as SQLite when it is either a standard `sqlite3.Connection` or SQLAlchemy's
  aiosqlite adapter wrapper; any other connection is left untouched.

  Args:
    dbapi_connection: The database connection to configure.
    connection_record: Metadata about the connection (unused).
  """
  del connection_record  # Unused; required by the listener signature.
  sqlite_connection_types = (
      sqlite3.Connection,
      AsyncAdapt_aiosqlite_connection,
  )
  if not isinstance(dbapi_connection, sqlite_connection_types):
    return

  # closing() releases the cursor even if executing the PRAGMA raises.
  with contextlib.closing(dbapi_connection.cursor()) as cursor:
    cursor.execute("PRAGMA foreign_keys=ON")


def get_async_engine(config: tiering_service_pb2.ServerConfig) -> AsyncEngine:
"""Returns an AsyncEngine configured from ServerConfig."""
input_url = config.db_connection_str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ class RequestType(enum.IntEnum):
REQUEST_TYPE_DELETE_FROM_ALL_TIERS = 3


class TierPathState(enum.IntEnum):
  """Lifecycle state of an asset's storage location (tier path)."""

  UNSPECIFIED = 0

  # Active states: the partial unique index on TierPath permits at most one
  # row per (asset, storage backend) in any of these states.
  PENDING = 1
  IN_PROGRESS = 2
  READY = 3

  # Inactive states: excluded from that uniqueness predicate, so several such
  # rows may coexist for the same (asset, storage backend).
  FAILED = 4
  DELETED = 5


class Asset(Base):
"""A CTS asset representing a complete checkpoint.

Expand Down Expand Up @@ -331,26 +342,44 @@ class TierPath(Base):
nullable=False,
default=lambda: str(uuid.uuid4()),
)
state = sqlalchemy.Column(
sqlalchemy.Enum(TierPathState),
default=TierPathState.PENDING,
nullable=False,
)

asset = sqlalchemy.orm.relationship("Asset", back_populates="tier_paths")
storage_backend = sqlalchemy.orm.relationship(
"StorageBackend", back_populates="tier_paths"
)

__table_args__ = (
# An asset can have at most one TierPath for a given storage backend.
sqlalchemy.UniqueConstraint(
# Enforce uniqueness of (asset, backend) only for active tier paths
# (PENDING, IN_PROGRESS, READY).
sqlalchemy.Index(
"uq_tier_path_active_backend",
"asset_uuid",
"storage_backend_id",
name="uq_tier_path_asset_backend",
unique=True,
sqlite_where=sqlalchemy.column("state").in_([
TierPathState.PENDING.name,
TierPathState.IN_PROGRESS.name,
TierPathState.READY.name,
]),
postgresql_where=sqlalchemy.column("state").in_([
TierPathState.PENDING.name,
TierPathState.IN_PROGRESS.name,
TierPathState.READY.name,
]),
),
)

def __repr__(self):
  """Returns a debug representation with identifying fields and state.

  SQLAlchemy applies a Python-side column default when the row is inserted,
  so `self.state` may still be None on an instance that has not been flushed.
  That case is rendered as `state=None` rather than raising AttributeError.
  """
  state_name = self.state.name if self.state is not None else None
  return (
      f"TierPath(id={self.id}, asset_uuid='{self.asset_uuid}',"
      f" storage_backend_id={self.storage_backend_id}, path='{self.path}',"
      f" state={state_name}, ready_at={self.ready_at},"
      f" expires_at={self.expires_at})"
  )


Expand All @@ -367,6 +396,13 @@ class AssetJob(Base):
status: Current execution status of the job, an instance of JobStatus.
target_tier_path_id: Foreign key to the targeted TierPath for operations
such as COPY or DELETE_FROM_INSTANCE.
request_id: A unique identifier (UUID) for this job execution request.
transfer_status: JSON dictionary containing progress and GCP operation
details.
expiration_at: Timestamp when the worker's lease on this job expires.
last_updated_at: Timestamp of the last status update or heartbeat.
worker_host: Hostname of the worker processing this job.
worker_pid: Process ID of the worker processing this job.
created_at: Timestamp when the job was created.
completed_at: Timestamp when the job was completed.
asset: Relationship to the associated Asset.
Expand Down Expand Up @@ -396,9 +432,24 @@ class AssetJob(Base):
# Target tier path for COPY and DELETE_FROM_INSTANCE requests
target_tier_path_id = sqlalchemy.Column(
sqlalchemy.Integer,
sqlalchemy.ForeignKey("tier_paths.id", ondelete="CASCADE"),
sqlalchemy.ForeignKey("tier_paths.id"),
nullable=True,
)
request_id = sqlalchemy.Column(
sqlalchemy.String,
nullable=False,
unique=True,
default=lambda: str(uuid.uuid4()),
)
transfer_status = sqlalchemy.Column(sqlalchemy.JSON, nullable=True)
expiration_at = sqlalchemy.Column(
sqlalchemy.DateTime(timezone=True), nullable=True
)
last_updated_at = sqlalchemy.Column(
sqlalchemy.DateTime(timezone=True), nullable=True
)
worker_host = sqlalchemy.Column(sqlalchemy.String, nullable=True)
worker_pid = sqlalchemy.Column(sqlalchemy.Integer, nullable=True)

created_at = sqlalchemy.Column(
sqlalchemy.DateTime(timezone=True),
Expand All @@ -416,9 +467,11 @@ class AssetJob(Base):
# target_tier_path is required in COPY and DELETE_FROM_INSTANCE requests.
sqlalchemy.CheckConstraint(
"""
(request_type IN ('REQUEST_TYPE_COPY', 'REQUEST_TYPE_DELETE_FROM_INSTANCE') AND target_tier_path_id IS NOT NULL)
(request_type IN ('REQUEST_TYPE_COPY', 'REQUEST_TYPE_DELETE_FROM_INSTANCE')
AND target_tier_path_id IS NOT NULL)
OR
(request_type IN ('REQUEST_TYPE_DELETE_FROM_ALL_TIERS', 'REQUEST_TYPE_UNSPECIFIED') AND target_tier_path_id IS NULL)
(request_type IN ('REQUEST_TYPE_DELETE_FROM_ALL_TIERS', 'REQUEST_TYPE_UNSPECIFIED')
AND target_tier_path_id IS NULL)
""",
name="check_asset_job_valid_payload",
),
Expand All @@ -430,5 +483,6 @@ def __repr__(self):
f" request_type='{self.request_type.name}',"
f" status='{self.status.name}',"
f" target_tier_path_id={self.target_tier_path_id},"
f" request_id='{self.request_id}',"
f" created_at={self.created_at}, completed_at={self.completed_at})"
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,22 @@
import greenlet # pylint: disable=unused-import
from orbax.checkpoint.experimental.tiering_service import db_schema
import sqlalchemy
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.future import select
from sqlalchemy.orm import sessionmaker


@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
  """Turns on SQLite foreign key enforcement for a newly opened connection.

  Registered as a listener for the "connect" event on `Engine`. Tests that
  expect an IntegrityError from a foreign key violation depend on this PRAGMA
  being set on the connection.

  Args:
    dbapi_connection: The DBAPI connection to configure.
    connection_record: Unused; required by the listener signature.
  """
  del connection_record
  cursor = dbapi_connection.cursor()
  try:
    cursor.execute("PRAGMA foreign_keys=ON")
  finally:
    # Release the cursor even if executing the PRAGMA raises.
    cursor.close()


class DbSchemaTest(parameterized.TestCase, unittest.IsolatedAsyncioTestCase):

async def asyncSetUp(self) -> None:
Expand Down Expand Up @@ -442,6 +452,115 @@ async def test_asset_job_queue(self) -> None:
fetched_job.status, db_schema.JobStatus.JOB_STATUS_COMPLETED
)

async def test_tier_path_deletion_fails_when_referenced_by_job(self) -> None:
  """Checks that a TierPath targeted by an AssetJob cannot be deleted."""
  asset_uuid = "uuid-delete-tp-fail"

  async with self.session_maker() as session:
    # Persist an asset with a single tier path on a GCS backend.
    asset = db_schema.Asset(
        asset_uuid=asset_uuid,
        path="/experiment/delete-tp-fail",
        user="testuser",
    )
    gcs_backend = db_schema.StorageBackend(
        level=0,
        zone="us-central1-a",
        backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
        prefix="gs://gcs-bucket",
    )
    referenced_path = db_schema.TierPath(
        asset_uuid=asset_uuid,
        storage_backend=gcs_backend,
        path="/path1",
    )
    session.add_all([asset, gcs_backend, referenced_path])
    await session.commit()

    # Persist a completed COPY job that targets the tier path.
    copy_job = db_schema.AssetJob(
        asset_uuid=asset_uuid,
        request_type=db_schema.RequestType.REQUEST_TYPE_COPY,
        status=db_schema.JobStatus.JOB_STATUS_COMPLETED,
        target_tier_path_id=referenced_path.id,
    )
    session.add(copy_job)
    await session.commit()

    # The job still references the tier path, so committing the delete must
    # be rejected with an IntegrityError.
    await session.delete(referenced_path)
    with self.assertRaises(sqlalchemy.exc.IntegrityError):
      await session.commit()

async def test_tier_path_conditional_uniqueness(self) -> None:
  """Checks (asset, backend) uniqueness applies only to active tier paths."""
  asset_uuid = "uuid-cond-uniq"

  async with self.session_maker() as session:
    asset = db_schema.Asset(
        asset_uuid=asset_uuid,
        path="/experiment/cond-uniq",
        user="testuser",
    )
    backend = db_schema.StorageBackend(
        level=0,
        zone="us-central1-a",
        backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
        prefix="gs://gcs-bucket",
    )
    session.add_all([asset, backend])
    await session.commit()

    def make_tier_path(path, state):
      # Every tier path in this test shares the same asset and backend.
      return db_schema.TierPath(
          asset_uuid=asset_uuid,
          storage_backend=backend,
          path=path,
          state=state,
      )

    async def query_tier_paths():
      # Returns the scalar result of all tier paths stored for the asset.
      rows = await session.execute(
          select(db_schema.TierPath).filter_by(asset_uuid=asset_uuid)
      )
      return rows.scalars()

    # Step 1: a single PENDING tier path is accepted.
    session.add(
        make_tier_path("/path-pending", db_schema.TierPathState.PENDING)
    )
    await session.commit()

    # Step 2: a second active (IN_PROGRESS) tier path for the same backend is
    # rejected.
    session.add(
        make_tier_path("/path-inprogress", db_schema.TierPathState.IN_PROGRESS)
    )
    with self.assertRaises(sqlalchemy.exc.IntegrityError):
      await session.commit()
    await session.rollback()

    # Step 3: move the existing tier path to FAILED.
    existing = (await query_tier_paths()).first()
    existing.state = db_schema.TierPathState.FAILED
    await session.commit()

    # Step 4: a new PENDING tier path for the same backend is now accepted.
    session.add(
        make_tier_path("/path-new-pending", db_schema.TierPathState.PENDING)
    )
    await session.commit()

    # Step 5: a DELETED tier path for the same backend is accepted as well.
    session.add(
        make_tier_path("/path-deleted", db_schema.TierPathState.DELETED)
    )
    await session.commit()

    # Three rows must now exist: FAILED, DELETED and PENDING.
    stored_paths = (await query_tier_paths()).all()
    self.assertLen(stored_paths, 3)

async def test_create_asset_duplicates_allowed_for_deleted_incomplete(self):
# Verify we can have duplicate path for DELETED or INCOMPLETE states
async with self.session_maker() as session:
Expand Down
Loading
Loading