diff --git a/astroml/db/schema.py b/astroml/db/schema.py index caea79b..f4da902 100644 --- a/astroml/db/schema.py +++ b/astroml/db/schema.py @@ -537,9 +537,105 @@ class NormalizedTransaction(Base): ) +# --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Model Registry +# --------------------------------------------------------------------------- + +class Model(Base): + """Machine learning model metadata for the model registry.""" + __tablename__ = "models" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(String(128), nullable=False, unique=True) + description: Mapped[Optional[str]] = mapped_column(Text) + framework: Mapped[str] = mapped_column(String(32), nullable=False) + task_type: Mapped[str] = mapped_column(String(32), nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="true") + created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + nullable=False, + server_default=func.now(), + onupdate=func.now(), + ) + + versions: Mapped[list["ModelVersion"]] = relationship( + back_populates="model", + cascade="all, delete-orphan", + ) + + __table_args__ = ( + Index("ix_models_framework", "framework"), + Index("ix_models_task_type", "task_type"), + Index("ix_models_is_active", "is_active"), + CheckConstraint( + "framework IN ('pytorch', 'tensorflow', 'sklearn', 'xgboost', 'lightgbm', 'custom')", + name="ck_models_framework", + ), + CheckConstraint( + "task_type IN ('classification', 'regression', 'anomaly_detection', 'clustering', 'custom')", + name="ck_models_task_type", + ), + ) + + +class ModelVersion(Base): + """Specific version of a machine learning model with artifacts and metrics.""" + + __tablename__ = "model_versions" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + model_id: Mapped[int] = mapped_column( + Integer, + ForeignKey("models.id"), + nullable=False, + ) + version: Mapped[str] = mapped_column(String(32), nullable=False) + artifact_path: Mapped[str] = mapped_column(String(512), nullable=False) + hyperparameters: Mapped[Optional[dict]] = mapped_column( + JSON().with_variant(JSONB(), "postgresql") + ) + metrics: Mapped[Optional[dict]] = mapped_column( + JSON().with_variant(JSONB(), "postgresql") + ) + status: Mapped[str] = mapped_column( + String(32), + nullable=False, + server_default="training", + ) + created_at: Mapped[datetime] = mapped_column( + nullable=False, + server_default=func.now(), + ) + updated_at: Mapped[datetime] = mapped_column( + nullable=False, + server_default=func.now(), + onupdate=func.now(), + ) + deployed_at: Mapped[Optional[datetime]] = mapped_column() + + model: Mapped["Model"] = relationship(back_populates="versions") + + __table_args__ = ( + UniqueConstraint( + "model_id", + "version", + name="uq_model_versions_model_version", + ), + Index("ix_model_versions_model_id", "model_id"), + Index("ix_model_versions_status", "status"), + Index("ix_model_versions_created_at", "created_at"), + CheckConstraint( + "status IN ('training', 'trained', 'deployed', 'archived', 'failed')", + name="ck_model_versions_status", + ), + ) + + # --------------------------------------------------------------------------- # Ledger Processing # --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- class ProcessedLedger(Base): """Tracking table for processed ledgers during backfill to ensure idempotency.""" @@ -547,7 +643,7 @@ class ProcessedLedger(Base): __tablename__ = "processed_ledgers" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - ledger_sequence: Mapped[int] = mapped_column(Integer, unique=True, nullable=False) +ledger_sequence: Mapped[int] = mapped_column(Integer, unique=True, nullable=False) source: Mapped[str] = mapped_column( String(256), nullable=False, @@ -559,9 +655,11 @@ class ProcessedLedger(Base): ) status: Mapped[ Literal["pending", "processing", "completed", "failed"] -] = mapped_column() - - +] = mapped_column( + String(16), + nullable=False, + server_default="pending", +) String(32), nullable=False, server_default="pending",