diff --git a/aidefense/config.py b/aidefense/config.py index 79402ef..d2e58ce 100644 --- a/aidefense/config.py +++ b/aidefense/config.py @@ -83,6 +83,8 @@ def __new__(cls, *args, **kwargs): return cls._instances[cls] + _logger = logging.getLogger("aidefense_sdk.config") + def __init__(self, *args, **kwargs): # Double-checked locking: fast path avoids the lock for already-init'd # singletons; the lock prevents concurrent first-time callers from both @@ -96,6 +98,8 @@ def __init__(self, *args, **kwargs): except Exception: self._instances.pop(type(self), None) raise + elif args or kwargs: + self._warn_if_params_differ(*args, **kwargs) def _set_region(self, region: str): if not isinstance(region, str): @@ -178,6 +182,53 @@ def _set_pool_config(self, pool_config: dict): "pool_maxsize": pool_config.get("pool_maxsize", self.DEFAULT_POOL_MAXSIZE), } + _INIT_PARAM_NAMES = ( + "region", "runtime_base_url", "management_base_url", "timeout", + ) + + def _warn_if_params_differ(self, *args, **kwargs): + """Log a warning when the singleton is re-requested with different parameters.""" + merged = dict(zip(self._INIT_PARAM_NAMES, args)) + merged.update(kwargs) + + _NORMALIZERS = { + "region": self._normalize_requested_region, + "runtime_base_url": self._normalize_url, + "management_base_url": self._normalize_url, + } + diffs = [] + for key in self._INIT_PARAM_NAMES: + if key not in merged or merged[key] is None: + continue + requested = merged[key] + normalizer = _NORMALIZERS.get(key) + if normalizer: + requested = normalizer(requested) + current = getattr(self, key, None) + if current is not None and requested != current: + diffs.append(f"{key}={current!r} (requested {merged[key]!r})") + if diffs: + self._logger.warning( + "%s singleton already initialized. Ignoring different " + "parameters: %s. Construct %s once and share it, or clear " + "%s._instances to re-initialize.", + type(self).__name__, + ", ".join(diffs), + type(self).__name__, + type(self).__name__, + ) + + @staticmethod + def _normalize_requested_region(region): + """Map short-code regions to canonical names for comparison.""" + _SHORT = {"us": "us-west-2", "eu": "eu-central-1", "apj": "ap-northeast-1"} + return _SHORT.get(region, region) if isinstance(region, str) else region + + @staticmethod + def _normalize_url(url): + """Strip trailing slash to match stored value normalization.""" + return url.rstrip("/") if isinstance(url, str) else url + @abstractmethod def _initialize(self, *args, **kwargs): pass diff --git a/aidefense/tests/test_config.py b/aidefense/tests/test_config.py index 4618f41..5b51d80 100644 --- a/aidefense/tests/test_config.py +++ b/aidefense/tests/test_config.py @@ -75,3 +75,59 @@ def test_config_with_pool_config(): config = Config(pool_config=pool_conf) assert config.connection_pool._pool_connections == 3 assert config.connection_pool._pool_maxsize == 7 + + +def test_config_warns_on_different_region(caplog): + Config(region="us-west-2") + with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"): + Config(region="eu-central-1") + assert "already initialized" in caplog.text + assert "region" in caplog.text + assert "eu-central-1" in caplog.text + + +def test_config_warns_on_different_timeout(caplog): + Config(timeout=30) + with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"): + Config(timeout=60) + assert "already initialized" in caplog.text + assert "timeout" in caplog.text + + +def test_config_no_warning_on_same_params(caplog): + Config(region="us-west-2", timeout=30) + with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"): + Config(region="us-west-2", timeout=30) + assert "already initialized" not in caplog.text + + +def test_config_no_warning_on_short_region_alias(caplog): + """'us' normalizes to 'us-west-2', so no mismatch warning.""" + Config(region="us-west-2") + with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"): + Config(region="us") + assert "already initialized" not in caplog.text + + +def test_config_no_warning_when_no_kwargs(caplog): + Config(region="eu-central-1") + with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"): + Config() + assert "already initialized" not in caplog.text + + +def test_config_warns_on_positional_region_mismatch(caplog): + """Positional args (not just kwargs) must trigger the warning.""" + Config("us-west-2") + with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"): + Config("eu-central-1") + assert "already initialized" in caplog.text + assert "region" in caplog.text + + +def test_config_no_false_positive_on_trailing_slash_url(caplog): + """URLs with trailing slashes should be normalized before comparison.""" + Config(runtime_base_url="https://custom.endpoint.com/") + with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"): + Config(runtime_base_url="https://custom.endpoint.com/") + assert "already initialized" not in caplog.text