From 3ef0e44a17a153dfffdaac6c58e49e919e955b56 Mon Sep 17 00:00:00 2001 From: PALAK JAISWAL Date: Tue, 26 May 2026 14:28:23 +0530 Subject: [PATCH 1/2] Implement fallback LLM provider logic --- backend/app/ai/providers/__init__.py | 226 +++++++++++++++--- backend/app/ai/service.py | 4 +- backend/app/config.py | 1 + .../tests/unit/ai/test_fallback_provider.py | 95 ++++++++ 4 files changed, 293 insertions(+), 33 deletions(-) create mode 100644 backend/tests/unit/ai/test_fallback_provider.py diff --git a/backend/app/ai/providers/__init__.py b/backend/app/ai/providers/__init__.py index 63c1b86f..378fc868 100644 --- a/backend/app/ai/providers/__init__.py +++ b/backend/app/ai/providers/__init__.py @@ -1,61 +1,223 @@ -"""Provider factory — returns the configured LLM provider instance.""" +"""Provider factory and fallback orchestration for LLM providers.""" +import logging +from collections.abc import AsyncIterator, Iterable +from typing import TypeVar + +from pydantic import BaseModel + from app.ai.providers.base import LLMProvider, LLMProviderError +T = TypeVar("T", bound=BaseModel) +logger = logging.getLogger(__name__) -def get_provider() -> LLMProvider: - """ - Instantiate and return the LLM provider configured in settings. +SUPPORTED_PROVIDERS = ("mock", "openrouter", "openai", "ollama") +DEFAULT_FALLBACK_ORDER = ("openrouter", "openai", "ollama") + + +class FallbackProvider(LLMProvider): + """Try configured LLM providers in order until one succeeds.""" - The provider is determined by ``ENVFORGE_LLM_PROVIDER`` env var: - - ``mock`` → deterministic responses for testing - - ``openrouter`` → routes to 100+ models via OpenRouter API - - ``openai`` → direct OpenAI API - - ``ollama`` → local inference (air gapped, implemented) + def __init__(self, providers: Iterable[LLMProvider]) -> None: + self.providers = list(providers) + if not self.providers: + raise LLMProviderError("fallback", "No LLM providers are configured.") + self.active_provider = self.providers[0] - Returns: - An instance of a class implementing :class:`LLMProvider`. + @property + def provider_name(self) -> str: + return type(self.active_provider).__name__ - Raises: - LLMProviderError: If the configured provider is unknown or misconfigured. + @property + def model(self) -> str: + return getattr(self.active_provider, "model", "unknown") + + @property + def last_token_usage(self) -> dict[str, int] | None: + token_usage = getattr(self.active_provider, "last_token_usage", None) + if callable(token_usage): + return token_usage() + if isinstance(token_usage, dict): + return token_usage + fallback_usage = getattr(self.active_provider, "_last_usage", None) + return fallback_usage if isinstance(fallback_usage, dict) else None + + async def complete( + self, + system_prompt: str, + user_message: str, + response_model: type[T], + ) -> T: + failures: list[str] = [] + + for provider in self.providers: + self.active_provider = provider + provider_name = type(provider).__name__ + try: + result = await provider.complete( + system_prompt=system_prompt, + user_message=user_message, + response_model=response_model, + ) + if failures: + logger.info("LLM fallback succeeded with %s", provider_name) + return result + except LLMProviderError as exc: + failures.append(str(exc)) + logger.warning("LLM provider %s failed: %s", provider_name, exc.reason) + + raise LLMProviderError( + "fallback", + "All LLM providers failed. Attempts: " + " | ".join(failures), + ) + + async def stream( + self, + system_prompt: str, + user_message: str, + response_model: type[T], + ) -> AsyncIterator[str]: + failures: list[str] = [] + + for provider in self.providers: + self.active_provider = provider + provider_name = type(provider).__name__ + yielded = False + try: + async for chunk in provider.stream( + system_prompt=system_prompt, + user_message=user_message, + response_model=response_model, + ): + yielded = True + yield chunk + return + except LLMProviderError as exc: + if yielded: + raise + failures.append(str(exc)) + logger.warning( + "LLM stream provider %s failed before yielding: %s", + provider_name, + exc.reason, + ) + + raise LLMProviderError( + "fallback", + "All LLM stream providers failed. Attempts: " + " | ".join(failures), + ) + + +def get_provider() -> LLMProvider: + """ + Instantiate the configured LLM provider with fallback providers. + + The primary provider is determined by ``ENVFORGE_LLM_PROVIDER``. For real + hosted/local providers, fallback order defaults to the remaining providers + in ``openrouter -> openai -> ollama`` order and can be overridden with the + comma-separated ``ENVFORGE_LLM_PROVIDER_FALLBACKS`` env var. """ from app.config import get_settings + settings = get_settings() + provider_names = _provider_chain(settings) + providers: list[LLMProvider] = [] - provider_name = settings.envforge_llm_provider.lower() + for index, provider_name in enumerate(provider_names): + try: + providers.append(_build_provider(provider_name, settings)) + except LLMProviderError: + if index == 0: + raise + logger.warning("Skipping unavailable fallback LLM provider: %s", provider_name) + if len(providers) == 1: + return providers[0] + return FallbackProvider(providers) + + +def _provider_chain(settings: object) -> list[str]: + primary = str(getattr(settings, "envforge_llm_provider")).lower() + fallback_setting = getattr(settings, "envforge_llm_provider_fallbacks", "") + + if primary not in SUPPORTED_PROVIDERS: + raise LLMProviderError( + primary, + f"Unknown LLM provider: '{primary}'. " + f"Valid options: {', '.join(SUPPORTED_PROVIDERS)}.", + ) + + if primary == "mock": + return ["mock"] + + configured_fallbacks = _parse_fallbacks(fallback_setting) + if configured_fallbacks: + chain = [primary, *configured_fallbacks] + else: + primary_index = DEFAULT_FALLBACK_ORDER.index(primary) + chain = list(DEFAULT_FALLBACK_ORDER[primary_index:]) + + return _dedupe_provider_names(chain) + + +def _parse_fallbacks(value: object) -> list[str]: + if not value: + return [] + if isinstance(value, str): + names = [name.strip().lower() for name in value.split(",")] + else: + names = [str(name).strip().lower() for name in value] + + invalid = [name for name in names if name and name not in SUPPORTED_PROVIDERS] + if invalid: + raise LLMProviderError( + "fallback", + f"Unknown fallback LLM provider(s): {', '.join(invalid)}. " + f"Valid options: {', '.join(SUPPORTED_PROVIDERS)}.", + ) + return [name for name in names if name] + + +def _dedupe_provider_names(provider_names: Iterable[str]) -> list[str]: + deduped: list[str] = [] + for provider_name in provider_names: + if provider_name not in deduped: + deduped.append(provider_name) + return deduped + + +def _build_provider(provider_name: str, settings: object) -> LLMProvider: if provider_name == "mock": from app.ai.providers.mock import MockProvider + return MockProvider() if provider_name == "openrouter": from app.ai.providers.openrouter import OpenRouterProvider + return OpenRouterProvider( - api_key=settings.openrouter_api_key, - model=settings.openrouter_model, - max_tokens=settings.ai_max_tokens, - temperature=settings.ai_temperature, + api_key=getattr(settings, "openrouter_api_key"), + model=getattr(settings, "openrouter_model"), + max_tokens=getattr(settings, "ai_max_tokens"), + temperature=getattr(settings, "ai_temperature"), ) if provider_name == "openai": from app.ai.providers.openai import OpenAIProvider - # Safely extract dynamic configuration values from environment context settings - api_key = settings.openai_api_key - base_url = getattr(settings, "openai_base_url", "https://api.openai.com/v1") return OpenAIProvider( - api_key=api_key, - base_url=base_url + api_key=getattr(settings, "openai_api_key"), + base_url=getattr(settings, "openai_base_url", "https://api.openai.com/v1"), + model=getattr(settings, "openai_model"), + max_tokens=getattr(settings, "ai_max_tokens"), + temperature=getattr(settings, "ai_temperature"), ) if provider_name == "ollama": from app.ai.providers.ollama import OllamaProvider + return OllamaProvider( - base_url=settings.ollama_base_url, - model=settings.ollama_model, - ) - - raise LLMProviderError( - provider_name, - f"Unknown LLM provider: '{provider_name}'. " - f"Valid options: mock, openrouter, openai, ollama.", - ) + base_url=getattr(settings, "ollama_base_url"), + model=getattr(settings, "ollama_model"), + ) + + raise LLMProviderError(provider_name, f"Unknown LLM provider: '{provider_name}'.") diff --git a/backend/app/ai/service.py b/backend/app/ai/service.py index e35354c8..e6eb7b77 100644 --- a/backend/app/ai/service.py +++ b/backend/app/ai/service.py @@ -81,7 +81,7 @@ async def troubleshoot( # ── Step 2: Call LLM ────────────────────────────────────────────── provider = get_provider() - provider_name = type(provider).__name__ + provider_name = getattr(provider, "provider_name", type(provider).__name__) model_name = getattr(provider, "model", "unknown") try: @@ -92,6 +92,8 @@ async def troubleshoot( user_message=user_message, response_model=TroubleshootResponse, ) + provider_name = getattr(provider, "provider_name", type(provider).__name__) + model_name = getattr(provider, "model", "unknown") except LLMProviderError as exc: # Log the failed attempt latency_ms = int((time.monotonic() - start_time) * 1000) diff --git a/backend/app/config.py b/backend/app/config.py index 20d36c27..d77791e7 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -50,6 +50,7 @@ def allowed_origins_list(self) -> list[str]: # ── AI / LLM ───────────────────────────────────────────── envforge_llm_provider: Literal["openai", "openrouter", "ollama", "mock"] = "mock" + envforge_llm_provider_fallbacks: str = "" openai_api_key: str = "" openai_model: str = "gpt-4o" openrouter_api_key: str = "" diff --git a/backend/tests/unit/ai/test_fallback_provider.py b/backend/tests/unit/ai/test_fallback_provider.py new file mode 100644 index 00000000..7c638c4f --- /dev/null +++ b/backend/tests/unit/ai/test_fallback_provider.py @@ -0,0 +1,95 @@ +from collections.abc import AsyncIterator +from typing import TypeVar + +import pytest +from pydantic import BaseModel + +from app.ai.providers import FallbackProvider, _provider_chain +from app.ai.providers.base import LLMProvider, LLMProviderError + +T = TypeVar("T", bound=BaseModel) + + +class DummyResponse(BaseModel): + message: str + + +class DummyProvider(LLMProvider): + def __init__(self, name: str, *, should_fail: bool = False) -> None: + self.name = name + self.model = f"{name}-model" + self.should_fail = should_fail + self.calls = 0 + self._last_usage = {"total_tokens": 7} + + async def complete( + self, + system_prompt: str, + user_message: str, + response_model: type[T], + ) -> T: + self.calls += 1 + if self.should_fail: + raise LLMProviderError(self.name, "temporary failure") + return response_model(message=f"ok from {self.name}") + + async def stream( + self, + system_prompt: str, + user_message: str, + response_model: type[T], + ) -> AsyncIterator[str]: + self.calls += 1 + if self.should_fail: + raise LLMProviderError(self.name, "temporary failure") + yield f"ok from {self.name}" + + +class SettingsStub: + envforge_llm_provider = "openrouter" + envforge_llm_provider_fallbacks = "" + + +@pytest.mark.asyncio +async def test_fallback_provider_tries_next_provider_after_error(): + primary = DummyProvider("primary", should_fail=True) + fallback = DummyProvider("fallback") + provider = FallbackProvider([primary, fallback]) + + result = await provider.complete("system", "user", DummyResponse) + + assert result.message == "ok from fallback" + assert primary.calls == 1 + assert fallback.calls == 1 + assert provider.provider_name == "DummyProvider" + assert provider.model == "fallback-model" + assert provider.last_token_usage == {"total_tokens": 7} + + +@pytest.mark.asyncio +async def test_fallback_provider_raises_after_all_providers_fail(): + provider = FallbackProvider([ + DummyProvider("primary", should_fail=True), + DummyProvider("fallback", should_fail=True), + ]) + + with pytest.raises(LLMProviderError) as exc_info: + await provider.complete("system", "user", DummyResponse) + + assert exc_info.value.provider == "fallback" + assert "All LLM providers failed" in exc_info.value.reason + assert "[primary] temporary failure" in exc_info.value.reason + assert "[fallback] temporary failure" in exc_info.value.reason + + +def test_provider_chain_defaults_to_remaining_real_providers(): + settings = SettingsStub() + + assert _provider_chain(settings) == ["openrouter", "openai", "ollama"] + + +def test_provider_chain_allows_configured_fallback_order(): + settings = SettingsStub() + settings.envforge_llm_provider_fallbacks = "ollama, openai, openrouter" + + assert _provider_chain(settings) == ["openrouter", "ollama", "openai"] From df560794afbb0d631fb5d4d62f77f23599457db1 Mon Sep 17 00:00:00 2001 From: PALAK JAISWAL Date: Tue, 26 May 2026 14:37:01 +0530 Subject: [PATCH 2/2] Add Trivy scan before Docker image push --- .github/workflows/release.yml | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index e2a4fa2b..643a5485 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -55,16 +55,30 @@ jobs: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - - name: Build and push Backend Docker image + - name: Build Backend Docker image uses: docker/build-push-action@v5 with: context: ./backend file: ./backend/Dockerfile - push: true + load: true tags: | ghcr.io/${{ env.REPO_LC }}/backend:latest ghcr.io/${{ env.REPO_LC }}/backend:${{ github.ref_name }} + - name: Scan Backend Docker image with Trivy + uses: aquasecurity/trivy-action@v0.36.0 + with: + image-ref: ghcr.io/${{ env.REPO_LC }}/backend:${{ github.ref_name }} + scan-type: image + vuln-type: os,library + severity: CRITICAL,HIGH + exit-code: "1" + + - name: Push Backend Docker image + run: | + docker push ghcr.io/${{ env.REPO_LC }}/backend:latest + docker push ghcr.io/${{ env.REPO_LC }}/backend:${{ github.ref_name }} + - name: Publish CLI package to PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: