Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions aidefense/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def __new__(cls, *args, **kwargs):

return cls._instances[cls]

_logger = logging.getLogger("aidefense_sdk.config")

def __init__(self, *args, **kwargs):
# Double-checked locking: fast path avoids the lock for already-init'd
# singletons; the lock prevents concurrent first-time callers from both
Expand All @@ -96,6 +98,8 @@ def __init__(self, *args, **kwargs):
except Exception:
self._instances.pop(type(self), None)
raise
elif args or kwargs:
self._warn_if_params_differ(*args, **kwargs)

def _set_region(self, region: str):
if not isinstance(region, str):
Expand Down Expand Up @@ -178,6 +182,53 @@ def _set_pool_config(self, pool_config: dict):
"pool_maxsize": pool_config.get("pool_maxsize", self.DEFAULT_POOL_MAXSIZE),
}

_INIT_PARAM_NAMES = (
"region", "runtime_base_url", "management_base_url", "timeout",
)

def _warn_if_params_differ(self, *args, **kwargs):
"""Log a warning when the singleton is re-requested with different parameters."""
merged = dict(zip(self._INIT_PARAM_NAMES, args))
merged.update(kwargs)

_NORMALIZERS = {
"region": self._normalize_requested_region,
"runtime_base_url": self._normalize_url,
"management_base_url": self._normalize_url,
}
diffs = []
for key in self._INIT_PARAM_NAMES:
if key not in merged or merged[key] is None:
continue
requested = merged[key]
normalizer = _NORMALIZERS.get(key)
if normalizer:
requested = normalizer(requested)
current = getattr(self, key, None)
if current is not None and requested != current:
diffs.append(f"{key}={current!r} (requested {merged[key]!r})")
if diffs:
self._logger.warning(
"%s singleton already initialized. Ignoring different "
"parameters: %s. Construct %s once and share it, or clear "
"%s._instances to re-initialize.",
type(self).__name__,
", ".join(diffs),
type(self).__name__,
type(self).__name__,
)

@staticmethod
def _normalize_requested_region(region):
"""Map short-code regions to canonical names for comparison."""
_SHORT = {"us": "us-west-2", "eu": "eu-central-1", "apj": "ap-northeast-1"}
return _SHORT.get(region, region) if isinstance(region, str) else region

@staticmethod
def _normalize_url(url):
"""Strip trailing slash to match stored value normalization."""
return url.rstrip("/") if isinstance(url, str) else url

@abstractmethod
def _initialize(self, *args, **kwargs):
pass
Expand Down
56 changes: 56 additions & 0 deletions aidefense/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,59 @@ def test_config_with_pool_config():
config = Config(pool_config=pool_conf)
assert config.connection_pool._pool_connections == 3
assert config.connection_pool._pool_maxsize == 7


def test_config_warns_on_different_region(caplog):
Config(region="us-west-2")
with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"):
Config(region="eu-central-1")
assert "already initialized" in caplog.text
assert "region" in caplog.text
assert "eu-central-1" in caplog.text


def test_config_warns_on_different_timeout(caplog):
Config(timeout=30)
with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"):
Config(timeout=60)
assert "already initialized" in caplog.text
assert "timeout" in caplog.text


def test_config_no_warning_on_same_params(caplog):
Config(region="us-west-2", timeout=30)
with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"):
Config(region="us-west-2", timeout=30)
assert "already initialized" not in caplog.text


def test_config_no_warning_on_short_region_alias(caplog):
"""'us' normalizes to 'us-west-2', so no mismatch warning."""
Config(region="us-west-2")
with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"):
Config(region="us")
assert "already initialized" not in caplog.text


def test_config_no_warning_when_no_kwargs(caplog):
Config(region="eu-central-1")
with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"):
Config()
assert "already initialized" not in caplog.text


def test_config_warns_on_positional_region_mismatch(caplog):
"""Positional args (not just kwargs) must trigger the warning."""
Config("us-west-2")
with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"):
Config("eu-central-1")
assert "already initialized" in caplog.text
assert "region" in caplog.text


def test_config_no_false_positive_on_trailing_slash_url(caplog):
"""URLs with trailing slashes should be normalized before comparison."""
Config(runtime_base_url="https://custom.endpoint.com/")
with caplog.at_level(logging.WARNING, logger="aidefense_sdk.config"):
Config(runtime_base_url="https://custom.endpoint.com/")
assert "already initialized" not in caplog.text
Loading