diff --git a/scripts/smoke_test_proxy.py b/scripts/smoke_test_proxy.py index 6863316..83320ce 100644 --- a/scripts/smoke_test_proxy.py +++ b/scripts/smoke_test_proxy.py @@ -181,6 +181,17 @@ async def test_openai() -> None: health = await client.get("http://127.0.0.1:18081/health") assert health.status_code == 200, f"health: {health.status_code}" + # /v1/models reports the client's real identity, not a hardcoded + # stub. External llama.cpp mode with no --model → the "default" + # placeholder (forge genuinely doesn't know the served name). + models = await client.get("http://127.0.0.1:18081/v1/models") + assert models.status_code == 200, f"models: {models.status_code}" + mdata = models.json() + assert mdata["object"] == "list", mdata + assert mdata["data"][0]["id"] == "default", mdata["data"] + assert mdata["data"][0]["id"] != "forge", "regressed to hardcoded stub" + print("[ok] /v1/models reports client identity (default placeholder)") + resp = await client.post( "http://127.0.0.1:18081/v1/chat/completions", json={ @@ -366,6 +377,14 @@ async def test_path1_anthropic_passthrough() -> None: try: async with httpx.AsyncClient(timeout=10.0) as client: + # /v1/models surfaces the configured model identity (not "forge"). + # This proxy was constructed with model="claude-mock". + models = await client.get("http://127.0.0.1:18085/v1/models") + assert models.status_code == 200, f"models: {models.status_code}" + mdata = models.json() + assert mdata["data"][0]["id"] == "claude-mock", mdata["data"] + print("[ok] /v1/models reports configured model 'claude-mock'") + cache_marker = {"type": "ephemeral"} resp = await client.post( "http://127.0.0.1:18085/v1/messages", diff --git a/src/forge/clients/base.py b/src/forge/clients/base.py index 2a500ca..250c440 100644 --- a/src/forge/clients/base.py +++ b/src/forge/clients/base.py @@ -7,7 +7,7 @@ from enum import Enum from typing import Any, Protocol, runtime_checkable -from forge.core.workflow import LLMResponse, ToolCall, TextResponse, ToolSpec +from forge.core.workflow import LLMResponse, ToolSpec # Verbatim OpenAI-shape payloads forwarded by the proxy. The proxy hands the # client the user's original ``tools`` array so the backend sees the exact @@ -85,6 +85,11 @@ class LLMClient(Protocol): api_format: str """Wire format for Message.to_api_dict(): 'ollama' or 'openai'.""" + model: str + """The backend model identity, sent verbatim as the wire "model" field + (the served-model-name, gguf stem, or model tag depending on backend). + Distinct from any sampling-registry lookup key a client also derives.""" + async def send( self, messages: list[dict[str, str]], diff --git a/src/forge/clients/llamafile.py b/src/forge/clients/llamafile.py index 9fd4ab7..36ed6ed 100644 --- a/src/forge/clients/llamafile.py +++ b/src/forge/clients/llamafile.py @@ -168,15 +168,17 @@ def __init__( "backends)." ) self.base_url = base_url - # gguf_path is the canonical identity. self.model is the stem (no - # .gguf / .llamafile suffix) — used for the wire-format model field - # (llama-server ignores it but it flows into eval JSONL rows) and - # for sampling-defaults lookup. + # gguf_path is the source path. self.model is the stem (no + # .gguf / .llamafile suffix) used as the wire "model" field + # (llama-server ignores it but it flows into eval JSONL rows). + # sampling_key is the registry-lookup key; for llamafile it equals + # the stem, so the wire id and the lookup key are the same string. self.gguf_path = Path(gguf_path) self.model = _SHARD_SUFFIX_RE.sub("", self.gguf_path.stem) + self.sampling_key = self.model # Apply per-model recommended sampling defaults. Caller's explicit # (non-None) kwargs win over the map field-by-field. - defaults = apply_sampling_defaults(self.model, strict=recommended_sampling) + defaults = apply_sampling_defaults(self.sampling_key, strict=recommended_sampling) self.temperature = temperature if temperature is not None else defaults.get("temperature") self.top_p = top_p if top_p is not None else defaults.get("top_p") self.top_k = top_k if top_k is not None else defaults.get("top_k") diff --git a/src/forge/clients/ollama.py b/src/forge/clients/ollama.py index 5a9cce8..29c44ca 100644 --- a/src/forge/clients/ollama.py +++ b/src/forge/clients/ollama.py @@ -57,9 +57,12 @@ def __init__( ) -> None: self.base_url = base_url self.model = model + # sampling_key is the registry-lookup key. For Ollama the wire "model" + # field and the lookup key are the same string (the model tag). + self.sampling_key = self.model # Apply per-model recommended sampling defaults. Caller's explicit # (non-None) kwargs win over the map field-by-field. - defaults = apply_sampling_defaults(model, strict=recommended_sampling) + defaults = apply_sampling_defaults(self.sampling_key, strict=recommended_sampling) self.temperature = temperature if temperature is not None else defaults.get("temperature") self.top_p = top_p if top_p is not None else defaults.get("top_p") self.top_k = top_k if top_k is not None else defaults.get("top_k") diff --git a/src/forge/clients/openai_compat.py b/src/forge/clients/openai_compat.py index e6b0319..f89b6fb 100644 --- a/src/forge/clients/openai_compat.py +++ b/src/forge/clients/openai_compat.py @@ -59,6 +59,9 @@ def __init__( ) -> None: self.base_url = base_url.rstrip("/") self.model = model + # sampling_key is the registry-lookup key. For OpenAI-compat backends + # the wire "model" field and the lookup key are the same string. + self.sampling_key = self.model # Apply per-model recommended sampling defaults. Caller's explicit # (non-None) kwargs win over the map field-by-field. With @@ -66,7 +69,7 @@ def __init__( # apply_sampling_defaults returns an empty dict silently — which # is the common case for hosted providers whose model identifiers # aren't in forge's registry. - defaults = apply_sampling_defaults(self.model, strict=recommended_sampling) + defaults = apply_sampling_defaults(self.sampling_key, strict=recommended_sampling) self.temperature = temperature if temperature is not None else defaults.get("temperature") self.top_p = top_p if top_p is not None else defaults.get("top_p") self.top_k = top_k if top_k is not None else defaults.get("top_k") diff --git a/src/forge/clients/vllm.py b/src/forge/clients/vllm.py index ff50e12..d2867fb 100644 --- a/src/forge/clients/vllm.py +++ b/src/forge/clients/vllm.py @@ -60,16 +60,18 @@ def __init__( recommended_sampling: bool = False, ) -> None: self.base_url = base_url - # model_path is the canonical identity, sent verbatim in the wire - # "model" field. self.model is the derived registry-lookup key. Both - # are set together so the (model_path, model) invariant holds — see - # _set_model_identity. Must run before apply_sampling_defaults below, - # which reads self.model. + # Two identity roles, set together (see _set_model_identity): + # self.model — the wire "model" field, sent verbatim. For vLLM + # this is the model path / HF repo id (or the + # served-model-name once discovered in external + # mode), which vLLM validates the request against. + # self.sampling_key — the derived registry-lookup key for + # apply_sampling_defaults below (must be set first). self._set_model_identity(model_path) # Apply per-model recommended sampling defaults. Caller's explicit # (non-None) kwargs win over the map field-by-field. - defaults = apply_sampling_defaults(self.model, strict=recommended_sampling) + defaults = apply_sampling_defaults(self.sampling_key, strict=recommended_sampling) self.temperature = temperature if temperature is not None else defaults.get("temperature") self.top_p = top_p if top_p is not None else defaults.get("top_p") self.top_k = top_k if top_k is not None else defaults.get("top_k") @@ -93,33 +95,34 @@ async def aclose(self) -> None: await self._http.aclose() @staticmethod - def _derive_model_field(model_path: str) -> str: - """Derive the sampling-registry lookup key from the canonical path. + def _derive_sampling_key(wire_id: str) -> str: + """Derive the sampling-registry lookup key from the wire model id. - vLLM accepts either a local directory (safetensors + config) or an HF - repo id (e.g. "google/gemma-4-26B-A4B-it"). The lookup key uses the - path stem so registry lookups match the existing GGUF-stem convention: + vLLM's wire id is either a local directory (safetensors + config) or an + HF repo id (e.g. "google/gemma-4-26B-A4B-it"). The lookup key uses the + stem so registry lookups match the existing GGUF-stem convention: a filesystem path → its directory name; an HF repo id (has "/") → its trailing segment; anything else → the string unchanged. """ - path_obj = Path(model_path) + path_obj = Path(wire_id) if path_obj.is_absolute() or path_obj.exists(): return path_obj.name - if "/" in model_path: - return model_path.split("/")[-1] - return model_path - - def _set_model_identity(self, model_path: str | Path) -> None: - """Set both identity fields atomically from one canonical path. - - ``model_path`` is the wire "model" field (sent verbatim); ``model`` is - the derived registry key. Used by ``__init__`` and by the proxy's - external-mode served-name adoption, so the ``(model_path, model)`` - invariant holds the same way in both — instead of mutating the two - fields separately after served-name discovery. + if "/" in wire_id: + return wire_id.split("/")[-1] + return wire_id + + def _set_model_identity(self, wire_id: str | Path) -> None: + """Set both identity fields atomically from one wire id. + + ``model`` is the wire "model" field (sent verbatim); ``sampling_key`` + is the derived registry-lookup key. Used by ``__init__`` and by the + proxy's external-mode served-name adoption, so the + ``(model, sampling_key)`` invariant holds the same way in both — + instead of mutating the two fields separately after served-name + discovery. """ - self.model_path = str(model_path) - self.model = self._derive_model_field(self.model_path) + self.model = str(wire_id) + self.sampling_key = self._derive_sampling_key(self.model) # Sampling fields recognized in per-call overrides. ``seed`` is # accepted only as a per-call override (not an instance field). @@ -194,7 +197,7 @@ async def send( reasoning server-side and is native-only. """ body: dict[str, Any] = { - "model": self.model_path, + "model": self.model, "messages": messages, "stream": False, } @@ -245,7 +248,7 @@ async def send_stream( accepted for protocol symmetry and ignored (see ``send``). """ body: dict[str, Any] = { - "model": self.model_path, + "model": self.model, "messages": messages, "stream": True, "stream_options": {"include_usage": True}, diff --git a/src/forge/proxy/server.py b/src/forge/proxy/server.py index 3ca3149..3312ee6 100644 --- a/src/forge/proxy/server.py +++ b/src/forge/proxy/server.py @@ -203,10 +203,10 @@ async def _handle_health(self, writer: asyncio.StreamWriter) -> None: await self._send_json(writer, 200, body) async def _handle_models(self, writer: asyncio.StreamWriter) -> None: - """GET /v1/models — returns a minimal model list.""" + """GET /v1/models — report the backend model the proxy is fronting.""" body = json.dumps({ "object": "list", - "data": [{"id": "forge", "object": "model"}], + "data": [{"id": self._client.model, "object": "model"}], }) await self._send_json(writer, 200, body) diff --git a/tests/unit/test_proxy_proxy.py b/tests/unit/test_proxy_proxy.py index 871ac45..f41a662 100644 --- a/tests/unit/test_proxy_proxy.py +++ b/tests/unit/test_proxy_proxy.py @@ -142,14 +142,14 @@ async def test_vllm_adopts_served_model_name(self) -> None: new_callable=AsyncMock, return_value="my-awq-model", ): client, _ = await proxy._setup_external() - assert client.model_path == "my-awq-model" assert client.model == "my-awq-model" + assert client.sampling_key == "my-awq-model" @pytest.mark.asyncio async def test_vllm_served_repo_id_keeps_wire_path_derives_registry_key(self) -> None: # An HF-repo-id served name must reach the wire verbatim (vLLM validates - # it), while the registry key is the derived stem — the (model_path, - # model) invariant, applied to served-name adoption. + # it), while the registry key is the derived stem — the (model, + # sampling_key) invariant, applied to served-name adoption. proxy = ProxyServer( backend_url="http://localhost:8000", backend="vllm", budget_tokens=8192, ) @@ -158,8 +158,8 @@ async def test_vllm_served_repo_id_keeps_wire_path_derives_registry_key(self) -> new_callable=AsyncMock, return_value="google/gemma-4-26B-A4B-it", ): client, _ = await proxy._setup_external() - assert client.model_path == "google/gemma-4-26B-A4B-it" - assert client.model == "gemma-4-26B-A4B-it" + assert client.model == "google/gemma-4-26B-A4B-it" + assert client.sampling_key == "gemma-4-26B-A4B-it" @pytest.mark.asyncio async def test_vllm_keeps_placeholder_when_discovery_fails(self) -> None: @@ -170,7 +170,7 @@ async def test_vllm_keeps_placeholder_when_discovery_fails(self) -> None: VLLMClient, "get_served_model_name", new_callable=AsyncMock, return_value=None, ): client, _ = await proxy._setup_external() - assert client.model_path == "default" + assert client.model == "default" @pytest.mark.asyncio async def test_url_v1_suffix_preserved(self) -> None: diff --git a/tests/unit/test_proxy_server.py b/tests/unit/test_proxy_server.py index af37ac2..17443e9 100644 --- a/tests/unit/test_proxy_server.py +++ b/tests/unit/test_proxy_server.py @@ -19,6 +19,7 @@ def _mock_client(response): """Create a mock LLMClient that returns the given response.""" client = AsyncMock() client.api_format = "ollama" + client.model = "mock-model" client.send = AsyncMock(return_value=response) return client @@ -143,7 +144,7 @@ async def test_models_endpoint(self, server_factory): assert status == 200 data = json.loads(body) assert data["object"] == "list" - assert len(data["data"]) > 0 + assert data["data"][0]["id"] == "mock-model" @pytest.mark.asyncio async def test_not_found(self, server_factory): diff --git a/tests/unit/test_vllm_client.py b/tests/unit/test_vllm_client.py index 0de394b..bda18a5 100644 --- a/tests/unit/test_vllm_client.py +++ b/tests/unit/test_vllm_client.py @@ -90,19 +90,21 @@ def _text_response(content: str = "hi", reasoning: str | None = None) -> dict: class TestConstructor: - def test_directory_path_derives_model_from_dirname(self) -> None: + def test_directory_path_derives_sampling_key_from_dirname(self) -> None: c = VLLMClient(model_path="/models/gemma-4-26B-A4B-it-AWQ-4bit") - assert c.model == "gemma-4-26B-A4B-it-AWQ-4bit" - assert c.model_path == "/models/gemma-4-26B-A4B-it-AWQ-4bit" + # model is the wire id (the path verbatim); sampling_key is the stem. + assert c.model == "/models/gemma-4-26B-A4B-it-AWQ-4bit" + assert c.sampling_key == "gemma-4-26B-A4B-it-AWQ-4bit" - def test_hf_repo_id_derives_model_from_trailing_segment(self) -> None: + def test_hf_repo_id_derives_sampling_key_from_trailing_segment(self) -> None: c = VLLMClient(model_path="google/gemma-4-26B-A4B-it") - assert c.model == "gemma-4-26B-A4B-it" - assert c.model_path == "google/gemma-4-26B-A4B-it" + assert c.model == "google/gemma-4-26B-A4B-it" + assert c.sampling_key == "gemma-4-26B-A4B-it" def test_single_token_model_path(self) -> None: c = VLLMClient(model_path="some-local-name") assert c.model == "some-local-name" + assert c.sampling_key == "some-local-name" def test_api_format_is_openai(self) -> None: c = VLLMClient(model_path="/models/x")