Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 194 additions & 32 deletions backend/app/ai/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,223 @@
"""Provider factory — returns the configured LLM provider instance."""
"""Provider factory and fallback orchestration for LLM providers."""
import logging
from collections.abc import AsyncIterator, Iterable
from typing import TypeVar

from pydantic import BaseModel

from app.ai.providers.base import LLMProvider, LLMProviderError

T = TypeVar("T", bound=BaseModel)
logger = logging.getLogger(__name__)

def get_provider() -> LLMProvider:
"""
Instantiate and return the LLM provider configured in settings.
# Provider names accepted both for the primary provider setting and for any
# configured fallback entries; anything else is rejected with LLMProviderError.
SUPPORTED_PROVIDERS = ("mock", "openrouter", "openai", "ollama")
# Default chain order. When no fallbacks are configured, the chain is this
# tuple sliced from the primary provider's position onward ("mock" is handled
# separately and never looked up here).
DEFAULT_FALLBACK_ORDER = ("openrouter", "openai", "ollama")


class FallbackProvider(LLMProvider):
"""Try configured LLM providers in order until one succeeds."""

The provider is determined by ``ENVFORGE_LLM_PROVIDER`` env var:
- ``mock`` → deterministic responses for testing
- ``openrouter`` → routes to 100+ models via OpenRouter API
- ``openai`` → direct OpenAI API
- ``ollama`` → local inference (air gapped, implemented)
def __init__(self, providers: Iterable[LLMProvider]) -> None:
    """Store the provider chain and mark its first entry as active.

    Args:
        providers: Providers to try, in priority order.

    Raises:
        LLMProviderError: If ``providers`` yields no entries.
    """
    chain = [*providers]
    self.providers = chain
    if not chain:
        raise LLMProviderError("fallback", "No LLM providers are configured.")
    self.active_provider = chain[0]

Returns:
An instance of a class implementing :class:`LLMProvider`.
@property
def provider_name(self) -> str:
    """Class name of the provider currently marked as active."""
    active_cls = type(self.active_provider)
    return active_cls.__name__

Raises:
LLMProviderError: If the configured provider is unknown or misconfigured.
@property
def model(self) -> str:
    """Model reported by the active provider, or ``"unknown"`` if it has none."""
    current = self.active_provider
    try:
        return current.model
    except AttributeError:
        return "unknown"

@property
def last_token_usage(self) -> dict[str, int] | None:
    """Token usage reported by the active provider, if any.

    Looks at the active provider's ``last_token_usage`` first (calling it
    when it is callable, returning it when it is a dict) and otherwise falls
    back to a ``_last_usage`` dict attribute. Returns ``None`` when neither
    source provides a dict.
    """
    current = self.active_provider
    usage = getattr(current, "last_token_usage", None)
    if callable(usage):
        return usage()
    if not isinstance(usage, dict):
        # Public attribute missing or not a dict: try the private one.
        usage = getattr(current, "_last_usage", None)
    return usage if isinstance(usage, dict) else None

async def complete(
    self,
    system_prompt: str,
    user_message: str,
    response_model: type[T],
) -> T:
    """Return the first successful completion from the provider chain.

    Providers are tried in the order they were configured. Each attempt
    updates ``active_provider`` so that ``provider_name``, ``model`` and
    ``last_token_usage`` describe the provider that was used last.

    Args:
        system_prompt: System prompt forwarded unchanged to each provider.
        user_message: User message forwarded unchanged to each provider.
        response_model: Pydantic model type the provider should return.

    Returns:
        The result of the first provider whose ``complete`` call succeeds.

    Raises:
        LLMProviderError: If every provider raises ``LLMProviderError``. The
            reason lists each failed attempt and the error is chained from
            the last provider failure. Other exception types are not caught
            and propagate immediately.
    """
    failures: list[str] = []
    last_error: LLMProviderError | None = None

    for provider in self.providers:
        self.active_provider = provider
        provider_name = type(provider).__name__
        try:
            result = await provider.complete(
                system_prompt=system_prompt,
                user_message=user_message,
                response_model=response_model,
            )
        except LLMProviderError as exc:
            last_error = exc
            failures.append(str(exc))
            logger.warning("LLM provider %s failed: %s", provider_name, exc.reason)
            continue

        # Only announce a fallback when at least one earlier provider failed.
        if failures:
            logger.info("LLM fallback succeeded with %s", provider_name)
        return result

    # Chain from the last failure so its traceback is not lost.
    raise LLMProviderError(
        "fallback",
        "All LLM providers failed. Attempts: " + " | ".join(failures),
    ) from last_error

async def stream(
    self,
    system_prompt: str,
    user_message: str,
    response_model: type[T],
) -> AsyncIterator[str]:
    """Stream chunks from the first provider that starts streaming.

    Providers are tried in configured order, and ``active_provider`` is
    updated on every attempt. Falling back is only possible while nothing
    has been yielded: once a provider has produced a chunk, a later failure
    from that provider is re-raised, because chunks already handed to the
    consumer cannot be retracted.

    Args:
        system_prompt: System prompt forwarded unchanged to each provider.
        user_message: User message forwarded unchanged to each provider.
        response_model: Pydantic model type forwarded to each provider.

    Yields:
        Chunks from a single provider's stream.

    Raises:
        LLMProviderError: Re-raised as-is when a provider fails mid-stream;
            otherwise raised once every provider has failed before yielding,
            with all attempts listed and chained from the last failure.
    """
    failures: list[str] = []
    last_error: LLMProviderError | None = None

    for provider in self.providers:
        self.active_provider = provider
        provider_name = type(provider).__name__
        yielded = False
        try:
            async for chunk in provider.stream(
                system_prompt=system_prompt,
                user_message=user_message,
                response_model=response_model,
            ):
                yielded = True
                yield chunk
            return
        except LLMProviderError as exc:
            if yielded:
                # Partial output is already out; do not restart elsewhere.
                raise
            last_error = exc
            failures.append(str(exc))
            logger.warning(
                "LLM stream provider %s failed before yielding: %s",
                provider_name,
                exc.reason,
            )

    # Chain from the last failure so its traceback is not lost.
    raise LLMProviderError(
        "fallback",
        "All LLM stream providers failed. Attempts: " + " | ".join(failures),
    ) from last_error


def get_provider() -> LLMProvider:
    """
    Instantiate the configured LLM provider with fallback providers.

    The primary provider is determined by ``ENVFORGE_LLM_PROVIDER``. For real
    hosted/local providers, fallback order defaults to the remaining providers
    in ``openrouter -> openai -> ollama`` order and can be overridden with the
    comma-separated ``ENVFORGE_LLM_PROVIDER_FALLBACKS`` env var.

    Returns:
        The single provider when only one could be built, otherwise a
        :class:`FallbackProvider` wrapping every provider that was built.

    Raises:
        LLMProviderError: If the provider chain is invalid or the primary
            (first) provider cannot be built. Fallback providers that fail
            with ``LLMProviderError`` are skipped with a warning instead.
    """
    # Imported here rather than at module level, as in the original code.
    from app.config import get_settings

    settings = get_settings()
    providers: list[LLMProvider] = []

    for index, provider_name in enumerate(_provider_chain(settings)):
        try:
            providers.append(_build_provider(provider_name, settings))
        except LLMProviderError:
            # Index 0 is the primary provider: its failure is fatal.
            if index == 0:
                raise
            logger.warning("Skipping unavailable fallback LLM provider: %s", provider_name)

    if len(providers) == 1:
        return providers[0]
    return FallbackProvider(providers)


def _provider_chain(settings: object) -> list[str]:
    """Resolve the ordered, de-duplicated provider names to try.

    Args:
        settings: Object exposing ``envforge_llm_provider`` and, optionally,
            ``envforge_llm_provider_fallbacks``.

    Returns:
        ``["mock"]`` for the mock provider. Otherwise the primary followed by
        the configured fallbacks, or, when none are configured, the default
        order starting at the primary's position.

    Raises:
        LLMProviderError: If the primary provider name is not supported.
    """
    primary = str(getattr(settings, "envforge_llm_provider")).lower()
    fallback_setting = getattr(settings, "envforge_llm_provider_fallbacks", "")

    if primary not in SUPPORTED_PROVIDERS:
        valid_options = ", ".join(SUPPORTED_PROVIDERS)
        raise LLMProviderError(
            primary,
            f"Unknown LLM provider: '{primary}'. Valid options: {valid_options}.",
        )

    # The mock provider never falls back to anything else.
    if primary == "mock":
        return ["mock"]

    overrides = _parse_fallbacks(fallback_setting)
    if not overrides:
        start = DEFAULT_FALLBACK_ORDER.index(primary)
        return _dedupe_provider_names(list(DEFAULT_FALLBACK_ORDER[start:]))
    return _dedupe_provider_names([primary, *overrides])


def _parse_fallbacks(value: object) -> list[str]:
if not value:
return []
if isinstance(value, str):
names = [name.strip().lower() for name in value.split(",")]
else:
names = [str(name).strip().lower() for name in value]

invalid = [name for name in names if name and name not in SUPPORTED_PROVIDERS]
if invalid:
raise LLMProviderError(
"fallback",
f"Unknown fallback LLM provider(s): {', '.join(invalid)}. "
f"Valid options: {', '.join(SUPPORTED_PROVIDERS)}.",
)
return [name for name in names if name]


def _dedupe_provider_names(provider_names: Iterable[str]) -> list[str]:
deduped: list[str] = []
for provider_name in provider_names:
if provider_name not in deduped:
deduped.append(provider_name)
return deduped


def _build_provider(provider_name: str, settings: object) -> LLMProvider:
    """Construct the provider instance for ``provider_name``.

    Each provider class is imported inside its own branch, so only the
    module for the requested provider is loaded.

    Args:
        provider_name: One of ``mock``, ``openrouter``, ``openai``, ``ollama``.
        settings: Settings object supplying the provider's configuration
            attributes (API keys, model names, base URLs, token limits).

    Returns:
        A provider instance configured from ``settings``.

    Raises:
        LLMProviderError: If ``provider_name`` is not a known provider.
    """
    if provider_name == "mock":
        from app.ai.providers.mock import MockProvider

        return MockProvider()

    if provider_name == "openrouter":
        from app.ai.providers.openrouter import OpenRouterProvider

        return OpenRouterProvider(
            api_key=getattr(settings, "openrouter_api_key"),
            model=getattr(settings, "openrouter_model"),
            max_tokens=getattr(settings, "ai_max_tokens"),
            temperature=getattr(settings, "ai_temperature"),
        )

    if provider_name == "openai":
        from app.ai.providers.openai import OpenAIProvider

        return OpenAIProvider(
            api_key=getattr(settings, "openai_api_key"),
            # Fall back to the public OpenAI endpoint when no base URL is set.
            base_url=getattr(settings, "openai_base_url", "https://api.openai.com/v1"),
            model=getattr(settings, "openai_model"),
            max_tokens=getattr(settings, "ai_max_tokens"),
            temperature=getattr(settings, "ai_temperature"),
        )

    if provider_name == "ollama":
        from app.ai.providers.ollama import OllamaProvider

        return OllamaProvider(
            base_url=getattr(settings, "ollama_base_url"),
            model=getattr(settings, "ollama_model"),
        )

    raise LLMProviderError(provider_name, f"Unknown LLM provider: '{provider_name}'.")
4 changes: 3 additions & 1 deletion backend/app/ai/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def troubleshoot(

# ── Step 2: Call LLM ──────────────────────────────────────────────
provider = get_provider()
provider_name = type(provider).__name__
provider_name = getattr(provider, "provider_name", type(provider).__name__)
model_name = getattr(provider, "model", "unknown")

try:
Expand All @@ -92,6 +92,8 @@ async def troubleshoot(
user_message=user_message,
response_model=TroubleshootResponse,
)
provider_name = getattr(provider, "provider_name", type(provider).__name__)
model_name = getattr(provider, "model", "unknown")
except LLMProviderError as exc:
# Log the failed attempt
latency_ms = int((time.monotonic() - start_time) * 1000)
Expand Down
1 change: 1 addition & 0 deletions backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def allowed_origins_list(self) -> list[str]:

# ── AI / LLM ─────────────────────────────────────────────
envforge_llm_provider: Literal["openai", "openrouter", "ollama", "mock"] = "mock"
envforge_llm_provider_fallbacks: str = ""
openai_api_key: str = ""
openai_model: str = "gpt-4o"
openrouter_api_key: str = ""
Expand Down
95 changes: 95 additions & 0 deletions backend/tests/unit/ai/test_fallback_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from collections.abc import AsyncIterator
from typing import TypeVar

import pytest
from pydantic import BaseModel

from app.ai.providers import FallbackProvider, _provider_chain
from app.ai.providers.base import LLMProvider, LLMProviderError

T = TypeVar("T", bound=BaseModel)


class DummyResponse(BaseModel):
    """Minimal response model passed as ``response_model`` in these tests."""

    # Text filled in by DummyProvider.complete ("ok from <name>").
    message: str


class DummyProvider(LLMProvider):
    """In-memory provider stub that counts calls and can be told to fail."""

    def __init__(self, name: str, *, should_fail: bool = False) -> None:
        """Configure the stub.

        Args:
            name: Label used in the model name, responses and errors.
            should_fail: When true, every call raises ``LLMProviderError``.
        """
        self.name = name
        self.model = f"{name}-model"
        self.should_fail = should_fail
        self.calls = 0
        self._last_usage = {"total_tokens": 7}

    def _register_call(self) -> None:
        """Count the call, then raise if this stub is configured to fail."""
        self.calls += 1
        if self.should_fail:
            raise LLMProviderError(self.name, "temporary failure")

    async def complete(
        self,
        system_prompt: str,
        user_message: str,
        response_model: type[T],
    ) -> T:
        """Return a ``response_model`` whose message names this stub."""
        self._register_call()
        return response_model(message=f"ok from {self.name}")

    async def stream(
        self,
        system_prompt: str,
        user_message: str,
        response_model: type[T],
    ) -> AsyncIterator[str]:
        """Yield a single chunk naming this stub."""
        self._register_call()
        yield f"ok from {self.name}"


class SettingsStub:
    """Stand-in settings object with only the attributes ``_provider_chain`` reads."""

    # Primary provider name.
    envforge_llm_provider = "openrouter"
    # Empty string: no configured fallbacks, so the default order applies.
    envforge_llm_provider_fallbacks = ""


@pytest.mark.asyncio
async def test_fallback_provider_tries_next_provider_after_error():
    """A failing primary is skipped and the fallback's result is returned."""
    failing = DummyProvider("primary", should_fail=True)
    healthy = DummyProvider("fallback")
    chain = FallbackProvider([failing, healthy])

    response = await chain.complete("system", "user", DummyResponse)

    assert response.message == "ok from fallback"
    # Each provider must have been attempted exactly once.
    assert failing.calls == 1
    assert healthy.calls == 1
    # The wrapper's metadata must describe the provider that succeeded.
    assert chain.provider_name == "DummyProvider"
    assert chain.model == "fallback-model"
    assert chain.last_token_usage == {"total_tokens": 7}


@pytest.mark.asyncio
async def test_fallback_provider_raises_after_all_providers_fail():
    """When every provider fails, one aggregated error lists all attempts."""
    chain = FallbackProvider(
        [DummyProvider(label, should_fail=True) for label in ("primary", "fallback")]
    )

    with pytest.raises(LLMProviderError) as exc_info:
        await chain.complete("system", "user", DummyResponse)

    error = exc_info.value
    assert error.provider == "fallback"
    assert "All LLM providers failed" in error.reason
    assert "[primary] temporary failure" in error.reason
    assert "[fallback] temporary failure" in error.reason


def test_provider_chain_defaults_to_remaining_real_providers():
    """With no configured fallbacks, the default order from the primary applies."""
    expected_chain = ["openrouter", "openai", "ollama"]

    resolved = _provider_chain(SettingsStub())

    assert resolved == expected_chain


def test_provider_chain_allows_configured_fallback_order():
    """Configured fallbacks follow the primary, with the duplicate primary dropped."""
    stub = SettingsStub()
    stub.envforge_llm_provider_fallbacks = "ollama, openai, openrouter"

    resolved = _provider_chain(stub)

    assert resolved == ["openrouter", "ollama", "openai"]
Loading