Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 45 additions & 17 deletions config/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from typing import Any, Dict, Optional

import yaml
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator

# ============================================================================
# PYDANTIC MODELS FOR TYPE-SAFE CONFIGURATION
Expand Down Expand Up @@ -61,10 +61,10 @@ class WorkspaceBounds(BaseModel):

model_config = ConfigDict(extra="ignore")

x_min: float = Field(..., description="Minimum X coordinate (meters)")
x_max: float = Field(..., description="Maximum X coordinate (meters)")
y_min: float = Field(..., description="Minimum Y coordinate (meters)")
y_max: float = Field(..., description="Maximum Y coordinate (meters)")
x_min: float = Field(0.0, description="Minimum X coordinate (meters)")
x_max: float = Field(0.1, description="Maximum X coordinate (meters)")
y_min: float = Field(0.0, description="Minimum Y coordinate (meters)")
y_max: float = Field(0.1, description="Maximum Y coordinate (meters)")

@field_validator("x_max")
@classmethod
Expand All @@ -88,9 +88,9 @@ class WorkspaceConfig(BaseModel):

model_config = ConfigDict(extra="ignore")

id: str = Field(..., description="Workspace identifier")
bounds: WorkspaceBounds = Field(..., description="Workspace boundaries")
center: list[float] = Field(..., description="Workspace center [x, y]")
id: str = Field("", description="Workspace identifier")
bounds: WorkspaceBounds = Field(default_factory=WorkspaceBounds, description="Workspace boundaries")
center: list[float] = Field(default_factory=lambda: [0.0, 0.0], description="Workspace center [x, y]")

@field_validator("center")
@classmethod
Expand Down Expand Up @@ -125,8 +125,8 @@ class RobotConfig(BaseModel):
verbose: bool = Field(False, description="Enable verbose output")
enable_camera: bool = Field(True, description="Enable camera")
camera_update_rate_hz: float = Field(2.0, ge=0.1, le=30.0, description="Camera update rate (Hz)")
workspace: Dict[str, WorkspaceConfig] = Field(..., description="Workspace configurations")
motion: MotionConfig = Field(..., description="Motion parameters")
workspace: Dict[str, WorkspaceConfig] = Field(default_factory=dict, description="Workspace configurations")
motion: MotionConfig = Field(default_factory=MotionConfig, description="Motion parameters")

@field_validator("type")
@classmethod
Expand Down Expand Up @@ -159,7 +159,7 @@ class DetectionConfig(BaseModel):
iou_threshold: float = Field(0.5, ge=0.0, le=1.0, description="IoU threshold")
max_detections: int = Field(100, ge=1, le=1000, description="Maximum detections")
default_labels: list[str] = Field(default_factory=list, description="Default object labels")
spatial: SpatialConfig = Field(..., description="Spatial query thresholds")
spatial: SpatialConfig = Field(default_factory=SpatialConfig, description="Spatial query thresholds")

@field_validator("model")
@classmethod
Expand All @@ -186,7 +186,7 @@ class LLMProviderConfig(BaseModel):
model_config = ConfigDict(extra="ignore")

enabled: bool = Field(True, description="Enable this provider")
default_model: str = Field(..., description="Default model name")
default_model: str = Field("unknown", description="Default model name")
models: list[str] = Field(default_factory=list, description="Available models")
rate_limit_rpm: Optional[int] = Field(None, ge=1, description="Rate limit (requests/min)")
base_url: Optional[str] = Field(None, description="Base URL for API")
Expand All @@ -203,7 +203,7 @@ class LLMConfig(BaseModel):
enable_cot: bool = Field(True, description="Enable chain-of-thought")
require_planning: bool = Field(True, description="Require planning phase")
max_iterations: int = Field(15, ge=1, le=50, description="Max tool-calling iterations")
providers: Dict[str, LLMProviderConfig] = Field(..., description="Provider configurations")
providers: Dict[str, LLMProviderConfig] = Field(default_factory=dict, description="Provider configurations")

@field_validator("default_provider")
@classmethod
Expand Down Expand Up @@ -266,7 +266,7 @@ class RedisConfig(BaseModel):
port: int = Field(6379, ge=1, le=65535, description="Redis port")
db: int = Field(0, ge=0, le=15, description="Redis database number")
decode_responses: bool = Field(True, description="Decode responses")
streams: RedisStreamsConfig = Field(..., description="Stream names")
streams: RedisStreamsConfig = Field(default_factory=RedisStreamsConfig, description="Stream names")


class GUIConfig(BaseModel):
Expand Down Expand Up @@ -296,10 +296,10 @@ class LoggingConfig(BaseModel):

model_config = ConfigDict(extra="ignore")

format: str = Field(..., description="Log message format")
format: str = Field("%(asctime)s - %(name)s - %(levelname)s - %(message)s", description="Log message format")
date_format: str = Field("%Y-%m-%d %H:%M:%S", description="Date format")
levels: Dict[str, str] = Field(..., description="Log levels by module")
rotation: LogRotationConfig = Field(..., description="Log rotation settings")
levels: Dict[str, str] = Field(default_factory=dict, description="Log levels by module")
rotation: LogRotationConfig = Field(default_factory=LogRotationConfig, description="Log rotation settings")


class EnvironmentOverrides(BaseModel):
Expand Down Expand Up @@ -332,6 +332,34 @@ class RobotMCPConfig(BaseModel):
logging: LoggingConfig
environments: Optional[Dict[str, EnvironmentOverrides]] = None

@model_validator(mode="after")
def validate_merged_config(self) -> "RobotMCPConfig":
"""
Validate that the final merged configuration is complete.

Ensures that essential fields are present after merging overrides.
"""
# 1. Robot must have workspaces
if not self.robot.workspace:
raise ValueError("Configuration error: No workspaces defined in 'robot.workspace'")

# 2. Each workspace must have a valid ID and non-zero bounds
for ws_id, ws in self.robot.workspace.items():
if not ws.id:
raise ValueError(f"Configuration error: Workspace '{ws_id}' has no 'id'")
if ws.bounds.x_min == ws.bounds.x_max and ws.bounds.y_min == ws.bounds.y_max:
raise ValueError(f"Configuration error: Workspace '{ws_id}' has zero-area bounds")

# 3. LLM must have at least one provider
if not self.llm.providers:
raise ValueError("Configuration error: No LLM providers defined in 'llm.providers'")

# 4. Redis must have stream names
if not self.redis.streams.camera or not self.redis.streams.detected_objects:
raise ValueError("Configuration error: Missing Redis stream names")

return self


# ============================================================================
# CONFIGURATION MANAGER
Expand Down
135 changes: 135 additions & 0 deletions tests/test_config_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Unit tests for configuration management."""

import os

import pytest
import yaml

from config.config_manager import ConfigManager


def test_config_load_default(tmp_path):
"""Test loading default configuration."""
# Create a minimal valid config
config_data = {
"server": {"host": "127.0.0.1", "port": 8000},
"robot": {
"type": "niryo",
"workspace": {
"niryo": {
"id": "niryo_ws",
"bounds": {"x_min": 0.1, "x_max": 0.2, "y_min": 0.1, "y_max": 0.2},
"center": [0.15, 0.15],
}
},
"motion": {},
},
"detection": {"spatial": {}},
"llm": {"providers": {"openai": {"default_model": "gpt-4o"}}},
"tts": {},
"redis": {"streams": {}},
"gui": {},
"logging": {"format": "%(message)s", "levels": {}, "rotation": {}},
}

config_file = tmp_path / "robot_config.yaml"
with open(config_file, "w") as f:
yaml.dump(config_data, f)

config = ConfigManager.load(config_path=str(config_file))
assert config.server.host == "127.0.0.1"
assert config.robot.type == "niryo"
assert "niryo" in config.robot.workspace


def test_config_with_environment_override(tmp_path):
"""Test environment-specific overrides."""
config_data = {
"server": {"host": "127.0.0.1", "port": 8000},
"robot": {
"type": "niryo",
"workspace": {
"niryo": {
"id": "niryo_ws",
"bounds": {"x_min": 0.1, "x_max": 0.2, "y_min": 0.1, "y_max": 0.2},
"center": [0.15, 0.15],
}
},
"motion": {},
},
"detection": {"spatial": {}},
"llm": {"providers": {"openai": {"default_model": "gpt-4o"}}},
"tts": {},
"redis": {"streams": {}},
"gui": {},
"logging": {"format": "%(message)s", "levels": {}, "rotation": {}},
"environments": {"development": {"server": {"log_level": "DEBUG"}, "robot": {"simulation": True}}},
}

config_file = tmp_path / "robot_config.yaml"
with open(config_file, "w") as f:
yaml.dump(config_data, f)

# Test partial override in development environment
os.environ["ROBOT_ENV"] = "development"
config = ConfigManager.load(config_path=str(config_file))

assert config.server.log_level == "DEBUG"
assert config.robot.simulation is True
# Ensure non-overridden fields are still there
assert config.server.host == "127.0.0.1"
assert "niryo" in config.robot.workspace


def test_config_invalid_missing_workspace(tmp_path):
"""Test missing workspace validation."""
config_data = {
"server": {"host": "127.0.0.1"},
"robot": {"type": "niryo", "workspace": {}}, # Empty workspace
"detection": {"spatial": {}},
"llm": {"providers": {"openai": {"default_model": "gpt-4o"}}},
"tts": {},
"redis": {"streams": {}},
"gui": {},
"logging": {"format": "%(message)s", "levels": {}, "rotation": {}},
}

config_file = tmp_path / "robot_config.yaml"
with open(config_file, "w") as f:
yaml.dump(config_data, f)

with pytest.raises(Exception) as excinfo:
ConfigManager.load(config_path=str(config_file))
assert "No workspaces defined" in str(excinfo.value)


def test_config_invalid_zero_area_workspace(tmp_path):
"""Test zero-area workspace validation."""
config_data = {
"server": {"host": "127.0.0.1"},
"robot": {
"type": "niryo",
"workspace": {
"niryo": {
"id": "niryo_ws",
"bounds": {"x_min": 0.1, "x_max": 0.1, "y_min": 0.1, "y_max": 0.1}, # Zero area
"center": [0.1, 0.1],
}
},
},
"detection": {"spatial": {}},
"llm": {"providers": {"openai": {"default_model": "gpt-4o"}}},
"tts": {},
"redis": {"streams": {}},
"gui": {},
"logging": {"format": "%(message)s", "levels": {}, "rotation": {}},
}

config_file = tmp_path / "robot_config.yaml"
with open(config_file, "w") as f:
yaml.dump(config_data, f)

with pytest.raises(Exception) as excinfo:
ConfigManager.load(config_path=str(config_file))
# It might be caught by field_validator or model_validator
assert any(msg in str(excinfo.value) for msg in ["zero-area bounds", "must be greater than"])
Loading