Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/forge/clients/llamafile.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,18 @@ def __init__(
cache_prompt: bool = True,
slot_id: int | None = None,
recommended_sampling: bool = False,
api_key: str | None = None,
) -> None:
self.base_url = base_url
# gguf_path is the canonical identity. self.model is the stem (no
# .gguf / .llamafile suffix) — used for the wire-format model field
# (llama-server ignores it but it flows into eval JSONL rows) and
# for sampling-defaults lookup.
self.gguf_path = Path(gguf_path)
self.model = self.gguf_path.stem
# When gguf_path is a real file, use stem (no .gguf/.llamafile suffix).
# When it's a plain model name (no GGUF extension), use it as-is.
_GGUF_SUFFIXES = {".gguf", ".llamafile"}
self.model = self.gguf_path.stem if self.gguf_path.suffix.lower() in _GGUF_SUFFIXES else str(gguf_path)
# Apply per-model recommended sampling defaults. Caller's explicit
# (non-None) kwargs win over the map field-by-field.
defaults = apply_sampling_defaults(self.model, strict=recommended_sampling)
Expand All @@ -160,7 +164,8 @@ def __init__(
self.repeat_penalty = repeat_penalty if repeat_penalty is not None else defaults.get("repeat_penalty")
self.presence_penalty = presence_penalty if presence_penalty is not None else defaults.get("presence_penalty")
self.mode = mode
self._http = httpx.AsyncClient(timeout=timeout)
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
self._http = httpx.AsyncClient(timeout=timeout, headers=headers)
self._think: bool = think if think is not None else True # auto = capture
self._cache_prompt = cache_prompt
self._slot_id = slot_id
Expand Down Expand Up @@ -407,6 +412,8 @@ async def get_context_length(self) -> int | None:
base = base[:-3]

resp = await self._http.get(f"{base}/props")
if resp.status_code == 404:
return None
resp.raise_for_status()
data = resp.json()

Expand Down
2 changes: 2 additions & 0 deletions src/forge/proxy/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def main() -> None:
parser.add_argument("--max-retries", type=int, default=3, help="Max retries per request (default: 3)")
parser.add_argument("--no-rescue", action="store_true", help="Disable rescue parsing")
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose logging")
parser.add_argument("--api-key", help="API key for backend authentication (sent as Bearer token)")

args = parser.parse_args()

Expand Down Expand Up @@ -82,6 +83,7 @@ def main() -> None:
serialize=serialize,
max_retries=args.max_retries,
rescue_enabled=not args.no_rescue,
api_key=args.api_key,
)

def _shutdown(sig: int, _frame: object) -> None:
Expand Down
9 changes: 6 additions & 3 deletions src/forge/proxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
serialize: bool | None = None,
max_retries: int = 3,
rescue_enabled: bool = True,
api_key: str | None = None,
) -> None:
"""
Args:
Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
self._port = port
self._max_retries = max_retries
self._rescue_enabled = rescue_enabled
self._api_key = api_key

# Auto-detect serialization: managed = single GPU = serialize
if serialize is None:
Expand Down Expand Up @@ -164,12 +166,13 @@ async def _async_start(self, ready: threading.Event) -> None:
if not base.endswith("/v1"):
base = base + "/v1"
# External mode: caller manages the backend, so we don't have a
# GGUF path. "default" is a placeholder identity for the wire
# model field (llama-server ignores it) and JSONL model field.
# GGUF path. Use model name if provided, else "default" placeholder.
wire_model = self._model or "default"
client = LlamafileClient(
gguf_path="default",
gguf_path=wire_model,
base_url=base,
mode="native",
api_key=self._api_key,
)
if self._budget_tokens is not None:
budget = self._budget_tokens
Expand Down