mirror of
https://github.com/nesquena/hermes-webui.git
synced 2026-07-04 14:41:05 +00:00
This commit is contained in:
+105
-33
@@ -26,12 +26,14 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
import http.client
|
||||
import socket as _socket
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from contextlib import closing
|
||||
from urllib.parse import parse_qs, quote, urljoin, urlsplit
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.request import HTTPRedirectHandler, ProxyHandler, Request, build_opener
|
||||
from urllib.request import HTTPRedirectHandler, HTTPSHandler, ProxyHandler, Request, build_opener
|
||||
from api.agent_sessions import (
|
||||
MESSAGING_SOURCES,
|
||||
_looks_like_default_cli_title,
|
||||
@@ -16299,6 +16301,24 @@ _TTS_PROXY_MAX_BYTES = 16 * 1024 * 1024
|
||||
_TTS_LOCALHOST_HOSTS = {"127.0.0.1", "::1", "localhost"}
|
||||
|
||||
|
||||
def _tts_addr_is_blocked(ip_str: str) -> bool:
|
||||
"""Return True when IP is in a private or otherwise non-routable class."""
|
||||
import ipaddress
|
||||
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return False
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_reserved
|
||||
or ip.is_multicast
|
||||
or ip.is_unspecified
|
||||
)
|
||||
|
||||
|
||||
def _tts_host_is_blocked_target(hostname: str) -> bool:
|
||||
"""True if the hostname resolves to (or literally is) a private / loopback /
|
||||
link-local / reserved / multicast address — the SSRF-risk targets that an
|
||||
@@ -16312,24 +16332,10 @@ def _tts_host_is_blocked_target(hostname: str) -> bool:
|
||||
if not host:
|
||||
return True
|
||||
|
||||
def _addr_blocked(ip_str: str) -> bool:
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return False
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_reserved
|
||||
or ip.is_multicast
|
||||
or ip.is_unspecified
|
||||
)
|
||||
|
||||
# Literal IP host?
|
||||
try:
|
||||
ipaddress.ip_address(host)
|
||||
return _addr_blocked(host)
|
||||
return _tts_addr_is_blocked(host)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
@@ -16346,11 +16352,42 @@ def _tts_host_is_blocked_target(hostname: str) -> bool:
|
||||
return False
|
||||
for info in infos:
|
||||
sockaddr = info[4]
|
||||
if sockaddr and _addr_blocked(str(sockaddr[0])):
|
||||
if sockaddr and _tts_addr_is_blocked(str(sockaddr[0])):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _tts_resolve_pinned_addresses(hostname: str, port: int | None) -> list[str]:
    """Resolve once, validate the whole RRset, and preserve candidate dial order.

    Args:
        hostname: Host taken from the configured OpenAI TTS base URL.
        port: Port handed to ``getaddrinfo`` (may be None).

    Returns:
        The distinct literal addresses returned by the resolver, in resolver
        order. Every address has passed ``_tts_addr_is_blocked``.

    Raises:
        ValueError: if the host is empty, resolution fails or yields no
            usable address, or *any* resolved address is a blocked target
            (a single bad record rejects the whole set).
    """
    import socket

    host = (hostname or "").strip().lower()
    if not host:
        raise ValueError("invalid OpenAI TTS base_url host")

    try:
        infos = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
    except Exception as exc:
        raise ValueError("could not resolve OpenAI TTS base_url host") from exc

    pinned_hosts: list[str] = []
    seen: set[str] = set()
    for info in infos:
        sockaddr = info[4]
        if not sockaddr:
            continue
        pinned_host = str(sockaddr[0])
        if _tts_addr_is_blocked(pinned_host):
            raise ValueError("resolved OpenAI TTS target is not allowed")
        # getaddrinfo may list the same address more than once (one entry per
        # family/type/proto combination); dialing it again after a failure
        # only wastes a connect timeout.
        if pinned_host in seen:
            continue
        seen.add(pinned_host)
        pinned_hosts.append(pinned_host)

    if not pinned_hosts:
        raise ValueError("could not resolve OpenAI TTS base_url host")
    return pinned_hosts
|
||||
|
||||
|
||||
def _tts_resolve_pinned_address(hostname: str) -> str:
    """Resolve ``hostname`` without a port and return its first vetted address.

    Validation failures from ``_tts_resolve_pinned_addresses`` propagate
    unchanged.
    """
    vetted = _tts_resolve_pinned_addresses(hostname, None)
    return vetted[0]
|
||||
|
||||
|
||||
def _normalized_openai_tts_base_url(base_url: str) -> str:
|
||||
from urllib.parse import urlsplit, urlunsplit
|
||||
|
||||
@@ -16410,6 +16447,52 @@ def _buffer_tts_audio_response(resp, *, max_bytes: int | None = None) -> bytes:
|
||||
raise ValueError("upstream audio exceeded byte limit")
|
||||
return bytes(audio_data)
|
||||
|
||||
class _NoRedirectTtsHandler(HTTPRedirectHandler):
|
||||
"""Refuse to follow redirects on the TTS call.
|
||||
|
||||
A redirect is never a legitimate response to POST /audio/speech and can
|
||||
carry the Authorization bearer to a target that bypasses the base_url check.
|
||||
"""
|
||||
|
||||
def redirect_request(self, req, fp, code, msg, headers, newurl):
|
||||
raise ValueError("OpenAI TTS upstream attempted a redirect")
|
||||
|
||||
|
||||
class _PinnedHTTPSConnection(http.client.HTTPSConnection):
    """HTTPS connection that dials only addresses vetted by the TTS resolver.

    The TCP socket is opened against a literal address returned by
    ``_tts_resolve_pinned_addresses``, while the TLS handshake still uses the
    original hostname (or tunnel host) as ``server_hostname``.
    """

    def connect(self):
        """Dial the first reachable pinned address, then wrap the socket in TLS.

        Candidates are tried in resolver order. If every dial fails, the last
        ``OSError`` is re-raised. Errors raised by the resolver itself
        propagate unchanged.
        """
        sys.audit("http.client.connect", self, self.host, self.port)

        dial_error = None
        for candidate in _tts_resolve_pinned_addresses(self.host, self.port):
            try:
                raw_sock = _socket.create_connection(
                    (candidate, self.port), self.timeout, self.source_address
                )
            except OSError as exc:
                # Remember the failure and fall through to the next candidate.
                dial_error = exc
                continue
            self.sock = raw_sock
            break
        else:
            if dial_error is None:
                raise OSError("could not connect to any pinned OpenAI TTS target")
            raise dial_error

        try:
            self.sock.setsockopt(_socket.IPPROTO_TCP, _socket.TCP_NODELAY, 1)
        except OSError as exc:
            # Only a missing TCP_NODELAY option is tolerated.
            if exc.errno != errno.ENOPROTOOPT:
                raise

        if self._tunnel_host:
            self._tunnel()

        # Verify the certificate against the name, not the dialed literal.
        sni_name = self._tunnel_host or self.host
        self.sock = self._context.wrap_socket(self.sock, server_hostname=sni_name)
|
||||
|
||||
|
||||
class _PinnedHTTPSHandler(HTTPSHandler):
    """HTTPS handler that opens requests through ``_PinnedHTTPSConnection``."""

    def https_open(self, req):
        """Open ``req`` with the pinned connection class and this handler's TLS context."""
        connection_class = _PinnedHTTPSConnection
        return self.do_open(connection_class, req, context=self._context)
|
||||
|
||||
|
||||
def _tts_open(req, *, timeout=30, opener_factory=None):
|
||||
"""Thin network seam for the TTS upstream fetch so tests can intercept it.
|
||||
@@ -16661,29 +16744,18 @@ def _handle_tts(handler, parsed):
|
||||
"voice": oai_voice,
|
||||
}).encode("utf-8")
|
||||
|
||||
from urllib.request import Request, build_opener, HTTPRedirectHandler, urlopen as _urlopen
|
||||
|
||||
class _NoRedirectTtsHandler(HTTPRedirectHandler):
|
||||
"""Refuse to follow redirects on the TTS call. A redirect is never a
|
||||
legitimate response to a POST /audio/speech, and following one would
|
||||
(a) carry the Authorization bearer to the redirect target and
|
||||
(b) let a public host bounce the request to a private/link-local
|
||||
SSRF target after the base-url validation already passed."""
|
||||
|
||||
def redirect_request(self, req, fp, code, msg, headers, newurl):
|
||||
raise ValueError("OpenAI TTS upstream attempted a redirect")
|
||||
|
||||
from urllib.request import Request, build_opener, urlopen as _urlopen
|
||||
req = Request(url, data=req_body, headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "audio/mpeg",
|
||||
})
|
||||
|
||||
# Use a no-redirect opener so an upstream redirect can't carry the bearer
|
||||
# to (or SSRF-bounce into) a different/private target. _tts_open is a thin
|
||||
# module seam so tests can still intercept the network call.
|
||||
# Use a pinned HTTPS opener so the resolved address is the one that gets
|
||||
# dialed. Keep the no-redirect handler in the same chain to block
|
||||
# bearer leaks and SSRF bounce redirects after hostname validation.
|
||||
try:
|
||||
with _tts_open(req, timeout=30, opener_factory=lambda: build_opener(_NoRedirectTtsHandler())) as resp:
|
||||
with _tts_open(req, timeout=30, opener_factory=lambda: build_opener(ProxyHandler({}), _NoRedirectTtsHandler(), _PinnedHTTPSHandler())) as resp:
|
||||
audio_data = _buffer_tts_audio_response(resp)
|
||||
except ValueError:
|
||||
logger.warning("OpenAI TTS rejected an invalid upstream response", exc_info=True)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""OpenAI-compatible TTS endpoint and UI wiring coverage for #4982."""
|
||||
import io
|
||||
import json
|
||||
import socket
|
||||
import ssl
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -55,6 +57,39 @@ class _StreamOnceResponse:
|
||||
return b""
|
||||
|
||||
|
||||
def _http_response_bytes(status_code: int, body=b"", *, reason="OK", headers=None):
|
||||
hdr = {"Content-Length": str(len(body)), "Content-Type": "audio/mpeg"}
|
||||
if headers:
|
||||
hdr.update(headers)
|
||||
lines = [f"HTTP/1.1 {status_code} {reason}\r\n"]
|
||||
for key, value in hdr.items():
|
||||
lines.append(f"{key}: {value}\r\n")
|
||||
lines.append("\r\n")
|
||||
return "".join(lines).encode("utf-8") + body
|
||||
|
||||
|
||||
class _FakeSocketForHttps:
|
||||
def __init__(self, response_body: bytes):
|
||||
self.writes = []
|
||||
self.response = io.BytesIO(response_body)
|
||||
self.closed = False
|
||||
|
||||
def sendall(self, data):
|
||||
self.writes.append(data)
|
||||
|
||||
def setsockopt(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
def makefile(self, *_args, **_kwargs):
|
||||
return self.response
|
||||
|
||||
def shutdown(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
def _post(body_dict, **kw):
|
||||
body = json.dumps(body_dict).encode()
|
||||
return _FakeHandler(body, **kw)
|
||||
@@ -151,6 +186,233 @@ def test_openai_tts_config_overrides(monkeypatch):
|
||||
assert captured["body"] == {"model": "tts-custom", "input": "Hello", "voice": "nova"}
|
||||
|
||||
|
||||
def test_tts_resolve_pinned_address_accepts_public_ip(monkeypatch):
    """A resolver answer holding one public address is returned as the pin."""
    def _public_answer(*_args, **_kwargs):
        return [(0, 0, 0, "", ("1.1.1.1", 0))]

    monkeypatch.setattr(socket, "getaddrinfo", _public_answer)

    pinned = routes._tts_resolve_pinned_address("1.1.1.1")
    assert pinned == "1.1.1.1"
|
||||
|
||||
|
||||
def test_tts_resolve_pinned_address_rejects_blocked_target(monkeypatch):
    """A hostname that resolves to an RFC 1918 address is refused."""
    def _private_answer(*_args, **_kwargs):
        return [(0, 0, 0, "", ("10.0.0.5", 0))]

    monkeypatch.setattr(socket, "getaddrinfo", _private_answer)

    with pytest.raises(ValueError, match="not allowed"):
        routes._tts_resolve_pinned_address("public.example.com")
|
||||
|
||||
|
||||
def test_tts_resolve_pinned_address_rejects_mixed_addresses(monkeypatch):
    """One blocked record rejects the whole answer, even if it is listed after another address."""
    def _mixed_answer(*_args, **_kwargs):
        answers = []
        for literal in ("203.0.113.10", "127.0.0.1"):
            answers.append((0, 0, 0, "", (literal, 0)))
        return answers

    monkeypatch.setattr(socket, "getaddrinfo", _mixed_answer)

    with pytest.raises(ValueError, match="not allowed"):
        routes._tts_resolve_pinned_address("public.example.com")
|
||||
|
||||
|
||||
def test_openai_tts_does_not_connect_to_rebound_private_address(monkeypatch):
    """DNS rebinding: the first lookup is public, the second is link-local, so no dial happens."""
    import api.config as config

    host = "rebind-openai.example.com"
    counts = {"getaddrinfo": 0}
    created = []

    def _rebinding_getaddrinfo(*_args, **_kwargs):
        counts["getaddrinfo"] += 1
        first_lookup = counts["getaddrinfo"] == 1
        literal = "1.1.1.1" if first_lookup else "169.254.169.254"
        return [(0, 0, 0, "", (literal, 443))]

    def _forbidden_create_connection(address, *_args, **_kwargs):
        dial_host, dial_port = address
        try:
            socket.inet_aton(dial_host)
        except OSError:
            # A hostname reached the dialer; resolve it so the failure shows the real target.
            resolved = socket.getaddrinfo(dial_host, dial_port)[0][4][0]
        else:
            resolved = dial_host
        created.append((resolved, dial_port))
        raise AssertionError(f"connect should not run with this test; got {(resolved, dial_port)}")

    def _passthrough_wrap(_ctx, raw_sock, *_args, **_kwargs):
        return raw_sock

    monkeypatch.setattr(socket, "getaddrinfo", _rebinding_getaddrinfo)
    monkeypatch.setattr(socket, "create_connection", _forbidden_create_connection)
    monkeypatch.setattr(ssl.SSLContext, "wrap_socket", _passthrough_wrap)
    monkeypatch.setenv("OPENAI_API_KEY", "sk-openai")
    monkeypatch.setattr(config, "get_config", lambda: {
        "tts": {"openai": {"base_url": f"https://{host}/v1"}}
    })

    h = _post({"text": "Hello", "engine": "openai"}, client="10.82.0.8")
    routes._handle_tts(h, None)

    assert counts["getaddrinfo"] == 2
    assert created == []
    assert h.status == 502
    error_text = (h.payload() or {}).get("error", "")
    assert "OpenAI TTS generation failed" in error_text
|
||||
|
||||
|
||||
def test_openai_tts_pinned_connection_preserves_host_and_sni(monkeypatch):
    """The socket dials the pinned literal; the Host header and TLS SNI keep the hostname."""
    import api.config as config

    host = "static-openai.example.com"
    fake_socket = _FakeSocketForHttps(_http_response_bytes(200, b"audio-openai"))
    observed = {}
    created = []

    def _static_getaddrinfo(*_args, **_kwargs):
        return [(0, 0, 0, "", ("1.1.1.1", 443))]

    def _recording_create_connection(address, *_args, **_kwargs):
        created.append(address)
        return fake_socket

    def _recording_wrap(_ctx, raw_sock, *_args, server_hostname=None, **_kwargs):
        observed["server_hostname"] = server_hostname
        return raw_sock

    monkeypatch.setattr(socket, "getaddrinfo", _static_getaddrinfo)
    monkeypatch.setattr(socket, "create_connection", _recording_create_connection)
    monkeypatch.setattr(ssl.SSLContext, "wrap_socket", _recording_wrap)
    monkeypatch.setenv("OPENAI_API_KEY", "sk-openai")
    monkeypatch.setattr(config, "get_config", lambda: {
        "tts": {"openai": {"base_url": f"https://{host}/v1"}}
    })

    h = _post({"text": "Hello", "engine": "openai"}, client="10.82.0.9")
    routes._handle_tts(h, None)

    assert h.status == 200
    assert h.sent_headers["Content-Type"] == "audio/mpeg"
    assert h.wfile.getvalue() == b"audio-openai"
    assert created == [("1.1.1.1", 443)]
    assert observed["server_hostname"] == host

    request_text = b"".join(fake_socket.writes).decode("utf-8", "replace")
    assert f"Host: {host}" in request_text
|
||||
|
||||
|
||||
def test_openai_tts_pinned_connection_tries_later_vetted_candidate(monkeypatch):
    """When the first vetted address cannot be dialed, the next candidate is tried."""
    import api.config as config

    host = "multi-openai.example.com"
    ipv6_literal = "2606:4700:4700::1111"
    ipv4_literal = "1.1.1.1"
    fake_socket = _FakeSocketForHttps(_http_response_bytes(200, b"audio-openai"))
    observed = {"getaddrinfo": 0, "server_hostname": None}
    created = []

    def _dual_stack_getaddrinfo(target_host, target_port=None, *_args, **_kwargs):
        observed["getaddrinfo"] += 1
        assert target_host == host
        port = target_port or 443
        v6_entry = (socket.AF_INET6, socket.SOCK_STREAM, 6, "", (ipv6_literal, port, 0, 0))
        v4_entry = (socket.AF_INET, socket.SOCK_STREAM, 6, "", (ipv4_literal, port))
        return [v6_entry, v4_entry]

    def _v6_unreachable_create_connection(address, *_args, **_kwargs):
        created.append(address)
        target = address[0]
        if target == ipv6_literal:
            raise OSError("ipv6 unavailable")
        if target == ipv4_literal:
            return fake_socket
        raise AssertionError(f"unexpected connect target {address}")

    def _recording_wrap(_ctx, raw_sock, *_args, server_hostname=None, **_kwargs):
        observed["server_hostname"] = server_hostname
        return raw_sock

    monkeypatch.setattr(socket, "getaddrinfo", _dual_stack_getaddrinfo)
    monkeypatch.setattr(socket, "create_connection", _v6_unreachable_create_connection)
    monkeypatch.setattr(ssl.SSLContext, "wrap_socket", _recording_wrap)
    monkeypatch.setenv("OPENAI_API_KEY", "sk-openai")
    monkeypatch.setattr(config, "get_config", lambda: {
        "tts": {"openai": {"base_url": f"https://{host}/v1"}}
    })

    h = _post({"text": "Hello", "engine": "openai"}, client="10.82.0.12")
    routes._handle_tts(h, None)

    assert h.status == 200
    assert h.wfile.getvalue() == b"audio-openai"
    assert created == [("2606:4700:4700::1111", 443), ("1.1.1.1", 443)]
    assert observed["getaddrinfo"] == 2
    assert observed["server_hostname"] == host
|
||||
|
||||
|
||||
def test_openai_tts_rejects_redirect_with_pinned_opener(monkeypatch):
    """An upstream 302 toward a link-local target fails the request; only the pinned address is dialed."""
    import api.config as config

    host = "redirect-openai.example.com"
    redirect_headers = {"Location": "http://169.254.169.254/v1/audio/speech"}
    fake_socket = _FakeSocketForHttps(
        _http_response_bytes(302, headers=redirect_headers, reason="Found")
    )
    created = []

    def _static_getaddrinfo(*_args, **_kwargs):
        return [(0, 0, 0, "", ("1.1.1.1", 443))]

    def _recording_create_connection(address, *_args, **_kwargs):
        created.append(address)
        return fake_socket

    def _passthrough_wrap(_ctx, raw_sock, *_args, **_kwargs):
        return raw_sock

    monkeypatch.setattr(socket, "getaddrinfo", _static_getaddrinfo)
    monkeypatch.setattr(socket, "create_connection", _recording_create_connection)
    monkeypatch.setattr(ssl.SSLContext, "wrap_socket", _passthrough_wrap)
    monkeypatch.setenv("OPENAI_API_KEY", "sk-openai")
    monkeypatch.setattr(config, "get_config", lambda: {
        "tts": {"openai": {"base_url": f"https://{host}/v1"}}
    })

    h = _post({"text": "Hello", "engine": "openai"}, client="10.82.0.10")
    routes._handle_tts(h, None)

    assert h.status in (500, 502)
    assert created == [("1.1.1.1", 443)]
    error_text = (h.payload() or {}).get("error", "")
    assert "OpenAI TTS generation failed" in error_text
|
||||
|
||||
|
||||
def test_openai_tts_ignores_https_proxy_and_dials_pinned_target(monkeypatch):
    """HTTPS_PROXY in the environment is ignored; the socket still goes to the pinned address."""
    import api.config as config

    host = "proxy-safe-openai.example.com"
    proxy_url = "http://127.0.0.1:8888"
    fake_socket = _FakeSocketForHttps(_http_response_bytes(200, b"audio-openai"))
    created = []

    def _selective_getaddrinfo(target_host, *_args, **_kwargs):
        known_answers = {
            host: ("1.1.1.1", 443),
            "127.0.0.1": ("127.0.0.1", 8888),
        }
        if target_host not in known_answers:
            raise AssertionError(f"unexpected DNS lookup for {target_host}")
        return [(0, 0, 0, "", known_answers[target_host])]

    def _recording_create_connection(address, *_args, **_kwargs):
        created.append(address)
        return fake_socket

    def _passthrough_wrap(_ctx, raw_sock, *_args, **_kwargs):
        return raw_sock

    monkeypatch.setattr(socket, "getaddrinfo", _selective_getaddrinfo)
    monkeypatch.setattr(socket, "create_connection", _recording_create_connection)
    monkeypatch.setattr(ssl.SSLContext, "wrap_socket", _passthrough_wrap)
    monkeypatch.setenv("OPENAI_API_KEY", "sk-openai")
    # Set both spellings of the proxy variable.
    monkeypatch.setenv("HTTPS_PROXY", proxy_url)
    monkeypatch.setenv("https_proxy", proxy_url)
    monkeypatch.setattr(config, "get_config", lambda: {
        "tts": {"openai": {"base_url": f"https://{host}/v1"}}
    })

    h = _post({"text": "Hello", "engine": "openai"}, client="10.82.0.11")
    routes._handle_tts(h, None)

    assert h.status == 200
    assert h.wfile.getvalue() == b"audio-openai"
    assert created == [("1.1.1.1", 443)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("base_url", [
|
||||
"http://169.254.169.254/v1",
|
||||
"https://user:pass@api.example.com/v1",
|
||||
|
||||
Reference in New Issue
Block a user