diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index dbff05e4d6..de9379d096 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -39,7 +39,7 @@ from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.security_config import SecurityConfig from crewai.skills.models import Skill -from crewai.state.checkpoint_config import CheckpointConfig +from crewai.state.checkpoint_config import CheckpointConfig, _coerce_checkpoint from crewai.tools.base_tool import BaseTool, Tool from crewai.types.callback import SerializableCallable from crewai.utilities.config import process_config @@ -300,7 +300,10 @@ def _validate_agent_executor(cls, v: Any) -> Any: default_factory=SecurityConfig, description="Security configuration for the agent, including fingerprinting.", ) - checkpoint: CheckpointConfig | bool | None = Field( + checkpoint: Annotated[ + CheckpointConfig | bool | None, + BeforeValidator(_coerce_checkpoint), + ] = Field( default=None, description="Automatic checkpointing configuration. " "True for defaults, False to opt out, None to inherit.", diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 4f9ebab5dc..e630ec5b00 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -104,7 +104,7 @@ def get_supported_content_types(provider: str, api: str | None = None) -> list[s from crewai.security.fingerprint import Fingerprint from crewai.security.security_config import SecurityConfig from crewai.skills.models import Skill -from crewai.state.checkpoint_config import CheckpointConfig +from crewai.state.checkpoint_config import CheckpointConfig, _coerce_checkpoint from crewai.task import Task from crewai.tasks.conditional_task import ConditionalTask from crewai.tasks.task_output import TaskOutput @@ -341,7 +341,10 @@ class Crew(FlowTrackable, BaseModel): default_factory=SecurityConfig, description="Security configuration for the crew, including fingerprinting.", ) - checkpoint: CheckpointConfig | bool | None = Field( + checkpoint: Annotated[ + CheckpointConfig | bool | None, + BeforeValidator(_coerce_checkpoint), + ] = Field( default=None, description="Automatic checkpointing configuration. " "True for defaults, False to opt out, None to inherit.", diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index 76a96b3f94..60d03b0693 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -113,7 +113,7 @@ ) from crewai.memory.memory_scope import MemoryScope, MemorySlice from crewai.memory.unified_memory import Memory -from crewai.state.checkpoint_config import CheckpointConfig +from crewai.state.checkpoint_config import CheckpointConfig, _coerce_checkpoint if TYPE_CHECKING: @@ -921,7 +921,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): max_method_calls: int = Field(default=100) execution_context: ExecutionContext | None = Field(default=None) - checkpoint: CheckpointConfig | bool | None = Field(default=None) + checkpoint: Annotated[ + CheckpointConfig | bool | None, + BeforeValidator(_coerce_checkpoint), + ] = Field(default=None) @classmethod def from_checkpoint( diff --git a/lib/crewai/src/crewai/state/checkpoint_config.py b/lib/crewai/src/crewai/state/checkpoint_config.py index 4c5499ff45..84c48bd4e9 100644 --- a/lib/crewai/src/crewai/state/checkpoint_config.py +++ b/lib/crewai/src/crewai/state/checkpoint_config.py @@ -2,9 +2,9 @@ from __future__ import annotations -from typing import Literal +from typing import Any, Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from crewai.state.provider.core import BaseProvider from crewai.state.provider.json_provider import JsonProvider @@ -158,6 +158,20 @@ ] +def _coerce_checkpoint(v: Any) -> Any: + """BeforeValidator for checkpoint fields on Crew/Flow/Agent. + + Converts True to CheckpointConfig and triggers handler registration. + """ + if v is True: + v = CheckpointConfig() + if isinstance(v, CheckpointConfig): + from crewai.state.checkpoint_listener import _ensure_handlers_registered + + _ensure_handlers_registered() + return v + + class CheckpointConfig(BaseModel): """Configuration for automatic checkpointing. @@ -185,6 +199,13 @@ class CheckpointConfig(BaseModel): "each write. None means keep all.", ) + @model_validator(mode="after") + def _register_handlers(self) -> CheckpointConfig: + from crewai.state.checkpoint_listener import _ensure_handlers_registered + + _ensure_handlers_registered() + return self + @property def trigger_all(self) -> bool: return "*" in self.on_events