diff --git a/config/config_manager.py b/config/config_manager.py index 445ee9e..ce601fa 100644 --- a/config/config_manager.py +++ b/config/config_manager.py @@ -28,7 +28,7 @@ from typing import Any, Dict, Optional import yaml -from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator # ============================================================================ # PYDANTIC MODELS FOR TYPE-SAFE CONFIGURATION @@ -61,10 +61,10 @@ class WorkspaceBounds(BaseModel): model_config = ConfigDict(extra="ignore") - x_min: float = Field(..., description="Minimum X coordinate (meters)") - x_max: float = Field(..., description="Maximum X coordinate (meters)") - y_min: float = Field(..., description="Minimum Y coordinate (meters)") - y_max: float = Field(..., description="Maximum Y coordinate (meters)") + x_min: float = Field(0.0, description="Minimum X coordinate (meters)") + x_max: float = Field(0.1, description="Maximum X coordinate (meters)") + y_min: float = Field(0.0, description="Minimum Y coordinate (meters)") + y_max: float = Field(0.1, description="Maximum Y coordinate (meters)") @field_validator("x_max") @classmethod @@ -88,9 +88,9 @@ class WorkspaceConfig(BaseModel): model_config = ConfigDict(extra="ignore") - id: str = Field(..., description="Workspace identifier") - bounds: WorkspaceBounds = Field(..., description="Workspace boundaries") - center: list[float] = Field(..., description="Workspace center [x, y]") + id: str = Field("", description="Workspace identifier") + bounds: WorkspaceBounds = Field(default_factory=WorkspaceBounds, description="Workspace boundaries") + center: list[float] = Field(default_factory=lambda: [0.0, 0.0], description="Workspace center [x, y]") @field_validator("center") @classmethod @@ -125,8 +125,8 @@ class RobotConfig(BaseModel): verbose: bool = Field(False, description="Enable verbose output") enable_camera: bool = Field(True, description="Enable camera") camera_update_rate_hz: float = Field(2.0, ge=0.1, le=30.0, description="Camera update rate (Hz)") - workspace: Dict[str, WorkspaceConfig] = Field(..., description="Workspace configurations") - motion: MotionConfig = Field(..., description="Motion parameters") + workspace: Dict[str, WorkspaceConfig] = Field(default_factory=dict, description="Workspace configurations") + motion: MotionConfig = Field(default_factory=MotionConfig, description="Motion parameters") @field_validator("type") @classmethod @@ -159,7 +159,7 @@ class DetectionConfig(BaseModel): iou_threshold: float = Field(0.5, ge=0.0, le=1.0, description="IoU threshold") max_detections: int = Field(100, ge=1, le=1000, description="Maximum detections") default_labels: list[str] = Field(default_factory=list, description="Default object labels") - spatial: SpatialConfig = Field(..., description="Spatial query thresholds") + spatial: SpatialConfig = Field(default_factory=SpatialConfig, description="Spatial query thresholds") @field_validator("model") @classmethod @@ -186,7 +186,7 @@ class LLMProviderConfig(BaseModel): model_config = ConfigDict(extra="ignore") enabled: bool = Field(True, description="Enable this provider") - default_model: str = Field(..., description="Default model name") + default_model: str = Field("unknown", description="Default model name") models: list[str] = Field(default_factory=list, description="Available models") rate_limit_rpm: Optional[int] = Field(None, ge=1, description="Rate limit (requests/min)") base_url: Optional[str] = Field(None, description="Base URL for API") @@ -203,7 +203,7 @@ class LLMConfig(BaseModel): enable_cot: bool = Field(True, description="Enable chain-of-thought") require_planning: bool = Field(True, description="Require planning phase") max_iterations: int = Field(15, ge=1, le=50, description="Max tool-calling iterations") - providers: Dict[str, LLMProviderConfig] = Field(..., description="Provider configurations") + providers: Dict[str, LLMProviderConfig] = Field(default_factory=dict, description="Provider configurations") @field_validator("default_provider") @classmethod @@ -266,7 +266,7 @@ class RedisConfig(BaseModel): port: int = Field(6379, ge=1, le=65535, description="Redis port") db: int = Field(0, ge=0, le=15, description="Redis database number") decode_responses: bool = Field(True, description="Decode responses") - streams: RedisStreamsConfig = Field(..., description="Stream names") + streams: RedisStreamsConfig = Field(default_factory=RedisStreamsConfig, description="Stream names") class GUIConfig(BaseModel): @@ -296,10 +296,10 @@ class LoggingConfig(BaseModel): model_config = ConfigDict(extra="ignore") - format: str = Field(..., description="Log message format") + format: str = Field("%(asctime)s - %(name)s - %(levelname)s - %(message)s", description="Log message format") date_format: str = Field("%Y-%m-%d %H:%M:%S", description="Date format") - levels: Dict[str, str] = Field(..., description="Log levels by module") - rotation: LogRotationConfig = Field(..., description="Log rotation settings") + levels: Dict[str, str] = Field(default_factory=dict, description="Log levels by module") + rotation: LogRotationConfig = Field(default_factory=LogRotationConfig, description="Log rotation settings") class EnvironmentOverrides(BaseModel): @@ -332,6 +332,34 @@ class RobotMCPConfig(BaseModel): logging: LoggingConfig environments: Optional[Dict[str, EnvironmentOverrides]] = None + @model_validator(mode="after") + def validate_merged_config(self) -> "RobotMCPConfig": + """ + Validate that the final merged configuration is complete. + + Ensures that essential fields are present after merging overrides. + """ + # 1. Robot must have workspaces + if not self.robot.workspace: + raise ValueError("Configuration error: No workspaces defined in 'robot.workspace'") + + # 2. Each workspace must have a valid ID and non-zero bounds + for ws_id, ws in self.robot.workspace.items(): + if not ws.id: + raise ValueError(f"Configuration error: Workspace '{ws_id}' has no 'id'") + if ws.bounds.x_min == ws.bounds.x_max and ws.bounds.y_min == ws.bounds.y_max: + raise ValueError(f"Configuration error: Workspace '{ws_id}' has zero-area bounds") + + # 3. LLM must have at least one provider + if not self.llm.providers: + raise ValueError("Configuration error: No LLM providers defined in 'llm.providers'") + + # 4. Redis must have stream names + if not self.redis.streams.camera or not self.redis.streams.detected_objects: + raise ValueError("Configuration error: Missing Redis stream names") + + return self + # ============================================================================ # CONFIGURATION MANAGER diff --git a/tests/test_config_manager.py b/tests/test_config_manager.py new file mode 100644 index 0000000..26a5281 --- /dev/null +++ b/tests/test_config_manager.py @@ -0,0 +1,135 @@ +"""Unit tests for configuration management.""" + +import os + +import pytest +import yaml + +from config.config_manager import ConfigManager + + +def test_config_load_default(tmp_path): + """Test loading default configuration.""" + # Create a minimal valid config + config_data = { + "server": {"host": "127.0.0.1", "port": 8000}, + "robot": { + "type": "niryo", + "workspace": { + "niryo": { + "id": "niryo_ws", + "bounds": {"x_min": 0.1, "x_max": 0.2, "y_min": 0.1, "y_max": 0.2}, + "center": [0.15, 0.15], + } + }, + "motion": {}, + }, + "detection": {"spatial": {}}, + "llm": {"providers": {"openai": {"default_model": "gpt-4o"}}}, + "tts": {}, + "redis": {"streams": {}}, + "gui": {}, + "logging": {"format": "%(message)s", "levels": {}, "rotation": {}}, + } + + config_file = tmp_path / "robot_config.yaml" + with open(config_file, "w") as f: + yaml.dump(config_data, f) + + config = ConfigManager.load(config_path=str(config_file)) + assert config.server.host == "127.0.0.1" + assert config.robot.type == "niryo" + assert "niryo" in config.robot.workspace + + +def test_config_with_environment_override(tmp_path): + """Test environment-specific overrides.""" + config_data = { + "server": {"host": "127.0.0.1", "port": 8000}, + "robot": { + "type": "niryo", + "workspace": { + "niryo": { + "id": "niryo_ws", + "bounds": {"x_min": 0.1, "x_max": 0.2, "y_min": 0.1, "y_max": 0.2}, + "center": [0.15, 0.15], + } + }, + "motion": {}, + }, + "detection": {"spatial": {}}, + "llm": {"providers": {"openai": {"default_model": "gpt-4o"}}}, + "tts": {}, + "redis": {"streams": {}}, + "gui": {}, + "logging": {"format": "%(message)s", "levels": {}, "rotation": {}}, + "environments": {"development": {"server": {"log_level": "DEBUG"}, "robot": {"simulation": True}}}, + } + + config_file = tmp_path / "robot_config.yaml" + with open(config_file, "w") as f: + yaml.dump(config_data, f) + + # Test partial override in development environment + os.environ["ROBOT_ENV"] = "development" + config = ConfigManager.load(config_path=str(config_file)) + + assert config.server.log_level == "DEBUG" + assert config.robot.simulation is True + # Ensure non-overridden fields are still there + assert config.server.host == "127.0.0.1" + assert "niryo" in config.robot.workspace + + +def test_config_invalid_missing_workspace(tmp_path): + """Test missing workspace validation.""" + config_data = { + "server": {"host": "127.0.0.1"}, + "robot": {"type": "niryo", "workspace": {}}, # Empty workspace + "detection": {"spatial": {}}, + "llm": {"providers": {"openai": {"default_model": "gpt-4o"}}}, + "tts": {}, + "redis": {"streams": {}}, + "gui": {}, + "logging": {"format": "%(message)s", "levels": {}, "rotation": {}}, + } + + config_file = tmp_path / "robot_config.yaml" + with open(config_file, "w") as f: + yaml.dump(config_data, f) + + with pytest.raises(Exception) as excinfo: + ConfigManager.load(config_path=str(config_file)) + assert "No workspaces defined" in str(excinfo.value) + + +def test_config_invalid_zero_area_workspace(tmp_path): + """Test zero-area workspace validation.""" + config_data = { + "server": {"host": "127.0.0.1"}, + "robot": { + "type": "niryo", + "workspace": { + "niryo": { + "id": "niryo_ws", + "bounds": {"x_min": 0.1, "x_max": 0.1, "y_min": 0.1, "y_max": 0.1}, # Zero area + "center": [0.1, 0.1], + } + }, + }, + "detection": {"spatial": {}}, + "llm": {"providers": {"openai": {"default_model": "gpt-4o"}}}, + "tts": {}, + "redis": {"streams": {}}, + "gui": {}, + "logging": {"format": "%(message)s", "levels": {}, "rotation": {}}, + } + + config_file = tmp_path / "robot_config.yaml" + with open(config_file, "w") as f: + yaml.dump(config_data, f) + + with pytest.raises(Exception) as excinfo: + ConfigManager.load(config_path=str(config_file)) + # It might be caught by field_validator or model_validator + assert any(msg in str(excinfo.value) for msg in ["zero-area bounds", "must be greater than"])