Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions scripts/smoke_test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,17 @@ async def test_openai() -> None:
health = await client.get("http://127.0.0.1:18081/health")
assert health.status_code == 200, f"health: {health.status_code}"

# /v1/models reports the client's real identity, not a hardcoded
# stub. External llama.cpp mode with no --model → the "default"
# placeholder (forge genuinely doesn't know the served name).
models = await client.get("http://127.0.0.1:18081/v1/models")
assert models.status_code == 200, f"models: {models.status_code}"
mdata = models.json()
assert mdata["object"] == "list", mdata
assert mdata["data"][0]["id"] == "default", mdata["data"]
assert mdata["data"][0]["id"] != "forge", "regressed to hardcoded stub"
print("[ok] /v1/models reports client identity (default placeholder)")

resp = await client.post(
"http://127.0.0.1:18081/v1/chat/completions",
json={
Expand Down Expand Up @@ -366,6 +377,14 @@ async def test_path1_anthropic_passthrough() -> None:

try:
async with httpx.AsyncClient(timeout=10.0) as client:
# /v1/models surfaces the configured model identity (not "forge").
# This proxy was constructed with model="claude-mock".
models = await client.get("http://127.0.0.1:18085/v1/models")
assert models.status_code == 200, f"models: {models.status_code}"
mdata = models.json()
assert mdata["data"][0]["id"] == "claude-mock", mdata["data"]
print("[ok] /v1/models reports configured model 'claude-mock'")

cache_marker = {"type": "ephemeral"}
resp = await client.post(
"http://127.0.0.1:18085/v1/messages",
Expand Down
7 changes: 6 additions & 1 deletion src/forge/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import Enum
from typing import Any, Protocol, runtime_checkable

from forge.core.workflow import LLMResponse, ToolCall, TextResponse, ToolSpec
from forge.core.workflow import LLMResponse, ToolSpec

# Verbatim OpenAI-shape payloads forwarded by the proxy. The proxy hands the
# client the user's original ``tools`` array so the backend sees the exact
Expand Down Expand Up @@ -85,6 +85,11 @@ class LLMClient(Protocol):
api_format: str
"""Wire format for Message.to_api_dict(): 'ollama' or 'openai'."""

model: str
"""The backend model identity, sent verbatim as the wire "model" field
(the served-model-name, gguf stem, or model tag depending on backend).
Distinct from any sampling-registry lookup key a client also derives."""

async def send(
self,
messages: list[dict[str, str]],
Expand Down
12 changes: 7 additions & 5 deletions src/forge/clients/llamafile.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,17 @@ def __init__(
"backends)."
)
self.base_url = base_url
# gguf_path is the canonical identity. self.model is the stem (no
# .gguf / .llamafile suffix) — used for the wire-format model field
# (llama-server ignores it but it flows into eval JSONL rows) and
# for sampling-defaults lookup.
# gguf_path is the source path. self.model is the stem (no
# .gguf / .llamafile suffix) used as the wire "model" field
# (llama-server ignores it but it flows into eval JSONL rows).
# sampling_key is the registry-lookup key; for llamafile it equals
# the stem, so the wire id and the lookup key are the same string.
self.gguf_path = Path(gguf_path)
self.model = _SHARD_SUFFIX_RE.sub("", self.gguf_path.stem)
self.sampling_key = self.model
# Apply per-model recommended sampling defaults. Caller's explicit
# (non-None) kwargs win over the map field-by-field.
defaults = apply_sampling_defaults(self.model, strict=recommended_sampling)
defaults = apply_sampling_defaults(self.sampling_key, strict=recommended_sampling)
self.temperature = temperature if temperature is not None else defaults.get("temperature")
self.top_p = top_p if top_p is not None else defaults.get("top_p")
self.top_k = top_k if top_k is not None else defaults.get("top_k")
Expand Down
5 changes: 4 additions & 1 deletion src/forge/clients/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ def __init__(
) -> None:
self.base_url = base_url
self.model = model
# sampling_key is the registry-lookup key. For Ollama the wire "model"
# field and the lookup key are the same string (the model tag).
self.sampling_key = self.model
# Apply per-model recommended sampling defaults. Caller's explicit
# (non-None) kwargs win over the map field-by-field.
defaults = apply_sampling_defaults(model, strict=recommended_sampling)
defaults = apply_sampling_defaults(self.sampling_key, strict=recommended_sampling)
self.temperature = temperature if temperature is not None else defaults.get("temperature")
self.top_p = top_p if top_p is not None else defaults.get("top_p")
self.top_k = top_k if top_k is not None else defaults.get("top_k")
Expand Down
5 changes: 4 additions & 1 deletion src/forge/clients/openai_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,17 @@ def __init__(
) -> None:
self.base_url = base_url.rstrip("/")
self.model = model
# sampling_key is the registry-lookup key. For OpenAI-compat backends
# the wire "model" field and the lookup key are the same string.
self.sampling_key = self.model

# Apply per-model recommended sampling defaults. Caller's explicit
# (non-None) kwargs win over the map field-by-field. With
# recommended_sampling=False (default) and an unknown model stem,
# apply_sampling_defaults returns an empty dict silently — which
# is the common case for hosted providers whose model identifiers
# aren't in forge's registry.
defaults = apply_sampling_defaults(self.model, strict=recommended_sampling)
defaults = apply_sampling_defaults(self.sampling_key, strict=recommended_sampling)
self.temperature = temperature if temperature is not None else defaults.get("temperature")
self.top_p = top_p if top_p is not None else defaults.get("top_p")
self.top_k = top_k if top_k is not None else defaults.get("top_k")
Expand Down
59 changes: 31 additions & 28 deletions src/forge/clients/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,18 @@ def __init__(
recommended_sampling: bool = False,
) -> None:
self.base_url = base_url
# model_path is the canonical identity, sent verbatim in the wire
# "model" field. self.model is the derived registry-lookup key. Both
# are set together so the (model_path, model) invariant holds — see
# _set_model_identity. Must run before apply_sampling_defaults below,
# which reads self.model.
# Two identity roles, set together (see _set_model_identity):
# self.model — the wire "model" field, sent verbatim. For vLLM
# this is the model path / HF repo id (or the
# served-model-name once discovered in external
# mode), which vLLM validates the request against.
# self.sampling_key — the derived registry-lookup key for
# apply_sampling_defaults below (must be set first).
self._set_model_identity(model_path)

# Apply per-model recommended sampling defaults. Caller's explicit
# (non-None) kwargs win over the map field-by-field.
defaults = apply_sampling_defaults(self.model, strict=recommended_sampling)
defaults = apply_sampling_defaults(self.sampling_key, strict=recommended_sampling)
self.temperature = temperature if temperature is not None else defaults.get("temperature")
self.top_p = top_p if top_p is not None else defaults.get("top_p")
self.top_k = top_k if top_k is not None else defaults.get("top_k")
Expand All @@ -93,33 +95,34 @@ async def aclose(self) -> None:
await self._http.aclose()

@staticmethod
def _derive_model_field(model_path: str) -> str:
"""Derive the sampling-registry lookup key from the canonical path.
def _derive_sampling_key(wire_id: str) -> str:
"""Derive the sampling-registry lookup key from the wire model id.

vLLM accepts either a local directory (safetensors + config) or an HF
repo id (e.g. "google/gemma-4-26B-A4B-it"). The lookup key uses the
path stem so registry lookups match the existing GGUF-stem convention:
vLLM's wire id is either a local directory (safetensors + config) or an
HF repo id (e.g. "google/gemma-4-26B-A4B-it"). The lookup key uses the
stem so registry lookups match the existing GGUF-stem convention:
a filesystem path → its directory name; an HF repo id (has "/") → its
trailing segment; anything else → the string unchanged.
"""
path_obj = Path(model_path)
path_obj = Path(wire_id)
if path_obj.is_absolute() or path_obj.exists():
return path_obj.name
if "/" in model_path:
return model_path.split("/")[-1]
return model_path

def _set_model_identity(self, model_path: str | Path) -> None:
"""Set both identity fields atomically from one canonical path.

``model_path`` is the wire "model" field (sent verbatim); ``model`` is
the derived registry key. Used by ``__init__`` and by the proxy's
external-mode served-name adoption, so the ``(model_path, model)``
invariant holds the same way in both — instead of mutating the two
fields separately after served-name discovery.
if "/" in wire_id:
return wire_id.split("/")[-1]
return wire_id

def _set_model_identity(self, wire_id: str | Path) -> None:
"""Set both identity fields atomically from one wire id.

``model`` is the wire "model" field (sent verbatim); ``sampling_key``
is the derived registry-lookup key. Used by ``__init__`` and by the
proxy's external-mode served-name adoption, so the
``(model, sampling_key)`` invariant holds the same way in both —
instead of mutating the two fields separately after served-name
discovery.
"""
self.model_path = str(model_path)
self.model = self._derive_model_field(self.model_path)
self.model = str(wire_id)
self.sampling_key = self._derive_sampling_key(self.model)

# Sampling fields recognized in per-call overrides. ``seed`` is
# accepted only as a per-call override (not an instance field).
Expand Down Expand Up @@ -194,7 +197,7 @@ async def send(
reasoning server-side and is native-only.
"""
body: dict[str, Any] = {
"model": self.model_path,
"model": self.model,
"messages": messages,
"stream": False,
}
Expand Down Expand Up @@ -245,7 +248,7 @@ async def send_stream(
accepted for protocol symmetry and ignored (see ``send``).
"""
body: dict[str, Any] = {
"model": self.model_path,
"model": self.model,
"messages": messages,
"stream": True,
"stream_options": {"include_usage": True},
Expand Down
4 changes: 2 additions & 2 deletions src/forge/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,10 @@ async def _handle_health(self, writer: asyncio.StreamWriter) -> None:
await self._send_json(writer, 200, body)

async def _handle_models(self, writer: asyncio.StreamWriter) -> None:
"""GET /v1/models — returns a minimal model list."""
"""GET /v1/models — report the backend model the proxy is fronting."""
body = json.dumps({
"object": "list",
"data": [{"id": "forge", "object": "model"}],
"data": [{"id": self._client.model, "object": "model"}],
})
await self._send_json(writer, 200, body)

Expand Down
12 changes: 6 additions & 6 deletions tests/unit/test_proxy_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,14 @@ async def test_vllm_adopts_served_model_name(self) -> None:
new_callable=AsyncMock, return_value="my-awq-model",
):
client, _ = await proxy._setup_external()
assert client.model_path == "my-awq-model"
assert client.model == "my-awq-model"
assert client.sampling_key == "my-awq-model"

@pytest.mark.asyncio
async def test_vllm_served_repo_id_keeps_wire_path_derives_registry_key(self) -> None:
# An HF-repo-id served name must reach the wire verbatim (vLLM validates
# it), while the registry key is the derived stem — the (model_path,
# model) invariant, applied to served-name adoption.
# it), while the registry key is the derived stem — the (model,
# sampling_key) invariant, applied to served-name adoption.
proxy = ProxyServer(
backend_url="http://localhost:8000", backend="vllm", budget_tokens=8192,
)
Expand All @@ -158,8 +158,8 @@ async def test_vllm_served_repo_id_keeps_wire_path_derives_registry_key(self) ->
new_callable=AsyncMock, return_value="google/gemma-4-26B-A4B-it",
):
client, _ = await proxy._setup_external()
assert client.model_path == "google/gemma-4-26B-A4B-it"
assert client.model == "gemma-4-26B-A4B-it"
assert client.model == "google/gemma-4-26B-A4B-it"
assert client.sampling_key == "gemma-4-26B-A4B-it"

@pytest.mark.asyncio
async def test_vllm_keeps_placeholder_when_discovery_fails(self) -> None:
Expand All @@ -170,7 +170,7 @@ async def test_vllm_keeps_placeholder_when_discovery_fails(self) -> None:
VLLMClient, "get_served_model_name", new_callable=AsyncMock, return_value=None,
):
client, _ = await proxy._setup_external()
assert client.model_path == "default"
assert client.model == "default"

@pytest.mark.asyncio
async def test_url_v1_suffix_preserved(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def _mock_client(response):
"""Create a mock LLMClient that returns the given response."""
client = AsyncMock()
client.api_format = "ollama"
client.model = "mock-model"
client.send = AsyncMock(return_value=response)
return client

Expand Down Expand Up @@ -143,7 +144,7 @@ async def test_models_endpoint(self, server_factory):
assert status == 200
data = json.loads(body)
assert data["object"] == "list"
assert len(data["data"]) > 0
assert data["data"][0]["id"] == "mock-model"

@pytest.mark.asyncio
async def test_not_found(self, server_factory):
Expand Down
14 changes: 8 additions & 6 deletions tests/unit/test_vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,21 @@ def _text_response(content: str = "hi", reasoning: str | None = None) -> dict:


class TestConstructor:
def test_directory_path_derives_model_from_dirname(self) -> None:
def test_directory_path_derives_sampling_key_from_dirname(self) -> None:
c = VLLMClient(model_path="/models/gemma-4-26B-A4B-it-AWQ-4bit")
assert c.model == "gemma-4-26B-A4B-it-AWQ-4bit"
assert c.model_path == "/models/gemma-4-26B-A4B-it-AWQ-4bit"
# model is the wire id (the path verbatim); sampling_key is the stem.
assert c.model == "/models/gemma-4-26B-A4B-it-AWQ-4bit"
assert c.sampling_key == "gemma-4-26B-A4B-it-AWQ-4bit"

def test_hf_repo_id_derives_model_from_trailing_segment(self) -> None:
def test_hf_repo_id_derives_sampling_key_from_trailing_segment(self) -> None:
c = VLLMClient(model_path="google/gemma-4-26B-A4B-it")
assert c.model == "gemma-4-26B-A4B-it"
assert c.model_path == "google/gemma-4-26B-A4B-it"
assert c.model == "google/gemma-4-26B-A4B-it"
assert c.sampling_key == "gemma-4-26B-A4B-it"

def test_single_token_model_path(self) -> None:
c = VLLMClient(model_path="some-local-name")
assert c.model == "some-local-name"
assert c.sampling_key == "some-local-name"

def test_api_format_is_openai(self) -> None:
c = VLLMClient(model_path="/models/x")
Expand Down
Loading