diff --git a/freerelay/core/resilience/budget.py b/freerelay/core/resilience/budget.py index 7edbf8b..d68a0b3 100644 --- a/freerelay/core/resilience/budget.py +++ b/freerelay/core/resilience/budget.py @@ -12,6 +12,14 @@ from dataclasses import dataclass, field +def _next_midnight_utc(now: float) -> float: + """Return the unix timestamp of the next 00:00 UTC strictly after `now`.""" + # 86400 seconds in a day; compute seconds since the most recent midnight UTC + seconds_in_day = 86_400.0 + secs_today = now % seconds_in_day + return now + (seconds_in_day - secs_today) + + @dataclass class BudgetState: """Token budget state for a single provider/key.""" @@ -20,7 +28,7 @@ class BudgetState: tokens_used_this_minute: int = 0 ewma_rate: float = 0.0 # tokens per minute last_updated_ts: float = field(default_factory=time.time) - daily_reset_ts: float = 0.0 # next midnight UTC + daily_reset_ts: float = field(default_factory=lambda: _next_midnight_utc(time.time())) daily_limit: int | None = None # None = unlimited @@ -60,6 +68,14 @@ def record_tokens(self, provider: str, tokens: int) -> None: state = self._get_state(provider) now = time.time() + # Check if we've crossed a daily reset boundary (midnight UTC). + # Without this, tokens_used_today keeps accumulating across days + # and is_budget_exhausted() reports the provider as exhausted + # forever once the limit is first hit. + if now >= state.daily_reset_ts: + state.tokens_used_today = 0 + state.daily_reset_ts = _next_midnight_utc(now) + # Check if we've crossed one or more minute boundaries elapsed = now - state.last_updated_ts if elapsed >= 60: @@ -132,3 +148,4 @@ def reset_daily(self, provider: str) -> None: state = self._get_state(provider) state.tokens_used_today = 0 state.tokens_used_this_minute = 0 + state.daily_reset_ts = _next_midnight_utc(time.time()) diff --git a/freerelay/providers/registry.py b/freerelay/providers/registry.py index 136e11a..18eb448 100644 --- a/freerelay/providers/registry.py +++ b/freerelay/providers/registry.py @@ -29,7 +29,10 @@ class ProviderRegistry: def __init__(self, plugin_dir: Path | None = None) -> None: self.plugin_dir = plugin_dir or Path.home() / ".freerelay" / "plugins" self._providers: dict[str, type[BaseProvider]] = {} - self._loaded_modules: dict[str, str] = {} + # Map each module_name -> set of provider names that module registered. + # This lets reload_plugins() look up the actual provider keys (not file + # paths) so _providers.pop() removes the right entries. + self._loaded_modules: dict[str, set[str]] = {} def register(self, provider_cls: type[BaseProvider]) -> None: """Register a provider class.""" @@ -64,6 +67,7 @@ def discover_plugins(self) -> int: if module_name in self._loaded_modules: continue + registered_names: set[str] = set() try: spec = importlib.util.spec_from_file_location(module_name, py_file) if spec and spec.loader: @@ -80,17 +84,24 @@ def discover_plugins(self) -> int: and attr is not BaseProvider ): self.register(attr) + registered_names.add(attr.name) # type: ignore[attr-defined] loaded += 1 logger.info( "Loaded plugin provider: %s from %s", - attr.name, + attr.name, # type: ignore[attr-defined] py_file.name, ) - self._loaded_modules[module_name] = str(py_file) + self._loaded_modules[module_name] = registered_names - except Exception as e: + except ImportError as e: logger.error("Failed to load plugin %s: %s", py_file.name, e) + except (SyntaxError, AttributeError) as e: + # SyntaxError: the plugin file itself failed to compile + # AttributeError: BaseProvider subclass is missing required attrs + # Both indicate a broken plugin and should be surfaced loudly + # rather than silently logged as a generic "Exception" + logger.error("Invalid plugin %s: %s", py_file.name, e) return loaded @@ -101,12 +112,13 @@ def reload_plugins(self) -> int: Returns: Number of providers after reload. """ - # Remove previously loaded plugin modules - for module_name in list(self._loaded_modules.keys()): + # Remove previously loaded plugin modules. + # Look up provider names by module_name (not by file path) so + # _providers.pop() actually removes the right entries. + for module_name, provider_names in list(self._loaded_modules.items()): sys.modules.pop(module_name, None) - # Also remove plugin-registered providers - provider_name = self._loaded_modules[module_name] - self._providers.pop(provider_name, None) + for provider_name in provider_names: + self._providers.pop(provider_name, None) self._loaded_modules.clear() # Re-discover