diff --git a/src/forge/clients/llamafile.py b/src/forge/clients/llamafile.py index f8082cc..481f0bd 100644 --- a/src/forge/clients/llamafile.py +++ b/src/forge/clients/llamafile.py @@ -142,6 +142,7 @@ def __init__( cache_prompt: bool = True, slot_id: int | None = None, recommended_sampling: bool = False, + api_key: str | None = None, ) -> None: self.base_url = base_url # gguf_path is the canonical identity. self.model is the stem (no @@ -149,7 +150,10 @@ def __init__( # (llama-server ignores it but it flows into eval JSONL rows) and # for sampling-defaults lookup. self.gguf_path = Path(gguf_path) - self.model = self.gguf_path.stem + # When gguf_path is a real file, use stem (no .gguf/.llamafile suffix). + # When it's a plain model name (no GGUF extension), use it as-is. + _GGUF_SUFFIXES = {".gguf", ".llamafile"} + self.model = self.gguf_path.stem if self.gguf_path.suffix.lower() in _GGUF_SUFFIXES else str(gguf_path) # Apply per-model recommended sampling defaults. Caller's explicit # (non-None) kwargs win over the map field-by-field. defaults = apply_sampling_defaults(self.model, strict=recommended_sampling) @@ -160,7 +164,8 @@ def __init__( self.repeat_penalty = repeat_penalty if repeat_penalty is not None else defaults.get("repeat_penalty") self.presence_penalty = presence_penalty if presence_penalty is not None else defaults.get("presence_penalty") self.mode = mode - self._http = httpx.AsyncClient(timeout=timeout) + headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} + self._http = httpx.AsyncClient(timeout=timeout, headers=headers) self._think: bool = think if think is not None else True # auto = capture self._cache_prompt = cache_prompt self._slot_id = slot_id @@ -407,6 +412,8 @@ async def get_context_length(self) -> int | None: base = base[:-3] resp = await self._http.get(f"{base}/props") + if resp.status_code == 404: + return None resp.raise_for_status() data = resp.json() diff --git a/src/forge/proxy/__main__.py b/src/forge/proxy/__main__.py index efb4c65..774a4d2 100644 --- a/src/forge/proxy/__main__.py +++ b/src/forge/proxy/__main__.py @@ -50,6 +50,7 @@ def main() -> None: parser.add_argument("--max-retries", type=int, default=3, help="Max retries per request (default: 3)") parser.add_argument("--no-rescue", action="store_true", help="Disable rescue parsing") parser.add_argument("--verbose", "-v", action="store_true", help="Verbose logging") + parser.add_argument("--api-key", help="API key for backend authentication (sent as Bearer token)") args = parser.parse_args() @@ -82,6 +83,7 @@ def main() -> None: serialize=serialize, max_retries=args.max_retries, rescue_enabled=not args.no_rescue, + api_key=args.api_key, ) def _shutdown(sig: int, _frame: object) -> None: diff --git a/src/forge/proxy/proxy.py b/src/forge/proxy/proxy.py index 330d138..f99b846 100644 --- a/src/forge/proxy/proxy.py +++ b/src/forge/proxy/proxy.py @@ -59,6 +59,7 @@ def __init__( serialize: bool | None = None, max_retries: int = 3, rescue_enabled: bool = True, + api_key: str | None = None, ) -> None: """ Args: @@ -92,6 +93,7 @@ def __init__( self._port = port self._max_retries = max_retries self._rescue_enabled = rescue_enabled + self._api_key = api_key # Auto-detect serialization: managed = single GPU = serialize if serialize is None: @@ -164,12 +166,13 @@ async def _async_start(self, ready: threading.Event) -> None: if not base.endswith("/v1"): base = base + "/v1" # External mode: caller manages the backend, so we don't have a - # GGUF path. "default" is a placeholder identity for the wire - # model field (llama-server ignores it) and JSONL model field. + # GGUF path. Use model name if provided, else "default" placeholder. + wire_model = self._model or "default" client = LlamafileClient( - gguf_path="default", + gguf_path=wire_model, base_url=base, mode="native", + api_key=self._api_key, ) if self._budget_tokens is not None: budget = self._budget_tokens