From 45e05e226e208d49fd2a1f74e5b721e0410b2db5 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Wed, 8 Apr 2026 02:34:37 +0800 Subject: [PATCH] refactor: make BaseProvider a BaseModel with provider_type discriminator Replace the Protocol with a BaseModel + ABC so providers serialize and deserialize natively via pydantic. Each provider gets a Literal provider_type field. CheckpointConfig.provider uses a discriminated union so the correct provider class is reconstructed from checkpoint JSON. --- .../src/crewai/state/checkpoint_config.py | 9 +++-- lib/crewai/src/crewai/state/provider/core.py | 35 ++++++------------- .../crewai/state/provider/json_provider.py | 3 ++ .../crewai/state/provider/sqlite_provider.py | 3 ++ 4 files changed, 23 insertions(+), 27 deletions(-) diff --git a/lib/crewai/src/crewai/state/checkpoint_config.py b/lib/crewai/src/crewai/state/checkpoint_config.py index 84c48bd4e9..38c6b0490d 100644 --- a/lib/crewai/src/crewai/state/checkpoint_config.py +++ b/lib/crewai/src/crewai/state/checkpoint_config.py @@ -2,12 +2,12 @@ from __future__ import annotations -from typing import Any, Literal +from typing import Annotated, Any, Literal from pydantic import BaseModel, Field, model_validator -from crewai.state.provider.core import BaseProvider from crewai.state.provider.json_provider import JsonProvider +from crewai.state.provider.sqlite_provider import SqliteProvider CheckpointEventType = Literal[ @@ -189,7 +189,10 @@ class CheckpointConfig(BaseModel): description="Event types that trigger a checkpoint write. " 'Use ["*"] to checkpoint on every event.', ) - provider: BaseProvider = Field( + provider: Annotated[ + JsonProvider | SqliteProvider, + Field(discriminator="provider_type"), + ] = Field( default_factory=JsonProvider, description="Storage backend. Defaults to JsonProvider.", ) diff --git a/lib/crewai/src/crewai/state/provider/core.py b/lib/crewai/src/crewai/state/provider/core.py index 46f079444c..0b12364c08 100644 --- a/lib/crewai/src/crewai/state/provider/core.py +++ b/lib/crewai/src/crewai/state/provider/core.py @@ -1,39 +1,22 @@ -"""Base protocol for state providers.""" +"""Base class for state providers.""" from __future__ import annotations -from typing import Any, Protocol, runtime_checkable +from abc import ABC, abstractmethod -from pydantic import GetCoreSchemaHandler -from pydantic_core import CoreSchema, core_schema +from pydantic import BaseModel -@runtime_checkable -class BaseProvider(Protocol): - """Interface for persisting and restoring runtime state checkpoints. +class BaseProvider(BaseModel, ABC): + """Base class for persisting and restoring runtime state checkpoints. Implementations handle the storage backend — filesystem, cloud, database, etc. — while ``RuntimeState`` handles serialization. """ - @classmethod - def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler - ) -> CoreSchema: - """Allow Pydantic to validate any ``BaseProvider`` instance.""" - - def _validate(v: Any) -> BaseProvider: - if isinstance(v, BaseProvider): - return v - raise TypeError(f"Expected a BaseProvider instance, got {type(v)}") - - return core_schema.no_info_plain_validator_function( - _validate, - serialization=core_schema.plain_serializer_function_ser_schema( - lambda v: type(v).__name__, info_arg=False - ), - ) + provider_type: str = "base" + @abstractmethod def checkpoint(self, data: str, location: str) -> str: """Persist a snapshot synchronously. @@ -46,6 +29,7 @@ def checkpoint(self, data: str, location: str) -> str: """ ... + @abstractmethod async def acheckpoint(self, data: str, location: str) -> str: """Persist a snapshot asynchronously. @@ -58,6 +42,7 @@ async def acheckpoint(self, data: str, location: str) -> str: """ ... + @abstractmethod def prune(self, location: str, max_keep: int) -> None: """Remove old checkpoints, keeping at most *max_keep*. @@ -67,6 +52,7 @@ def prune(self, location: str, max_keep: int) -> None: """ ... + @abstractmethod def from_checkpoint(self, location: str) -> str: """Read a snapshot synchronously. @@ -78,6 +64,7 @@ def from_checkpoint(self, location: str) -> str: """ ... + @abstractmethod async def afrom_checkpoint(self, location: str) -> str: """Read a snapshot asynchronously. diff --git a/lib/crewai/src/crewai/state/provider/json_provider.py b/lib/crewai/src/crewai/state/provider/json_provider.py index d2ac75d9c2..f9763e6f3c 100644 --- a/lib/crewai/src/crewai/state/provider/json_provider.py +++ b/lib/crewai/src/crewai/state/provider/json_provider.py @@ -7,6 +7,7 @@ import logging import os from pathlib import Path +from typing import Literal import uuid import aiofiles @@ -21,6 +22,8 @@ class JsonProvider(BaseProvider): """Persists runtime state checkpoints as JSON files on the local filesystem.""" + provider_type: Literal["json"] = "json" + def checkpoint(self, data: str, location: str) -> str: """Write a JSON checkpoint file. diff --git a/lib/crewai/src/crewai/state/provider/sqlite_provider.py b/lib/crewai/src/crewai/state/provider/sqlite_provider.py index ae014dda35..e54f561804 100644 --- a/lib/crewai/src/crewai/state/provider/sqlite_provider.py +++ b/lib/crewai/src/crewai/state/provider/sqlite_provider.py @@ -5,6 +5,7 @@ from datetime import datetime, timezone from pathlib import Path import sqlite3 +from typing import Literal import uuid import aiosqlite @@ -47,6 +48,8 @@ class SqliteProvider(BaseProvider): used as the database file path. """ + provider_type: Literal["sqlite"] = "sqlite" + def checkpoint(self, data: str, location: str) -> str: """Write a checkpoint to the SQLite database.