diff --git a/api/routes.py b/api/routes.py
index 9dab83317..3de180a98 100644
--- a/api/routes.py
+++ b/api/routes.py
@@ -26,12 +26,14 @@ import sys
 import threading
 import time
 import uuid
+import http.client
+import socket as _socket
 from collections import defaultdict
 from pathlib import Path
 from contextlib import closing
 from urllib.parse import parse_qs, quote, urljoin, urlsplit
 from urllib.error import HTTPError, URLError
-from urllib.request import HTTPRedirectHandler, ProxyHandler, Request, build_opener
+from urllib.request import HTTPRedirectHandler, HTTPSHandler, ProxyHandler, Request, build_opener
 from api.agent_sessions import (
     MESSAGING_SOURCES,
     _looks_like_default_cli_title,
@@ -16299,6 +16301,33 @@ _TTS_PROXY_MAX_BYTES = 16 * 1024 * 1024
 _TTS_LOCALHOST_HOSTS = {"127.0.0.1", "::1", "localhost"}
 
 
+def _tts_addr_is_blocked(ip_str: str) -> bool:
+    """Return True when ``ip_str`` must not be dialed by the TTS proxy.
+
+    Blocked classes are private, loopback, link-local, reserved, multicast and
+    unspecified addresses. A string that does not parse as an IP address is
+    blocked as well, so the SSRF guard fails closed instead of open.
+    """
+    import ipaddress
+
+    try:
+        ip = ipaddress.ip_address(ip_str)
+    except ValueError:
+        # Fail closed: an address that cannot be classified is never dialed.
+        return True
+    # NOTE(review): is_private is False for the shared address space
+    # 100.64.0.0/10; consider also testing `not ip.is_global` if that range
+    # must be blocked too.
+    return (
+        ip.is_private
+        or ip.is_loopback
+        or ip.is_link_local
+        or ip.is_reserved
+        or ip.is_multicast
+        or ip.is_unspecified
+    )
+
+
 def _tts_host_is_blocked_target(hostname: str) -> bool:
     """True if the hostname resolves to (or literally is) a private / loopback /
     link-local / reserved / multicast address — the SSRF-risk targets that an
@@ -16312,24 +16341,10 @@ def _tts_host_is_blocked_target(hostname: str) -> bool:
     if not host:
         return True
 
-    def _addr_blocked(ip_str: str) -> bool:
-        try:
-            ip = ipaddress.ip_address(ip_str)
-        except ValueError:
-            return False
-        return (
-            ip.is_private
-            or ip.is_loopback
-            or ip.is_link_local
-            or ip.is_reserved
-            or ip.is_multicast
-            or ip.is_unspecified
-        )
-
     # Literal IP host?
     try:
         ipaddress.ip_address(host)
-        return _addr_blocked(host)
+        return _tts_addr_is_blocked(host)
     except ValueError:
         pass
 
@@ -16346,11 +16361,50 @@ def _tts_host_is_blocked_target(hostname: str) -> bool:
         return False
     for info in infos:
         sockaddr = info[4]
-        if sockaddr and _addr_blocked(str(sockaddr[0])):
+        if sockaddr and _tts_addr_is_blocked(str(sockaddr[0])):
             return True
     return False
 
 
+def _tts_resolve_pinned_addresses(hostname: str, port: int | None) -> list[str]:
+    """Resolve ``hostname`` once and return its vetted addresses in dial order.
+
+    Every address in the answer is checked; a single blocked address rejects
+    the whole set, so a mixed public/private RRset can never be dialed.
+
+    Raises:
+        ValueError: if the host is empty, cannot be resolved, resolves to no
+            usable address, or resolves to any blocked address.
+    """
+    import socket
+
+    host = (hostname or "").strip().lower()
+    if not host:
+        raise ValueError("invalid OpenAI TTS base_url host")
+
+    try:
+        infos = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
+    except Exception as exc:
+        raise ValueError("could not resolve OpenAI TTS base_url host") from exc
+    pinned_hosts = []
+    for info in infos:
+        sockaddr = info[4]
+        if not sockaddr:
+            continue
+        pinned_host = str(sockaddr[0])
+        if _tts_addr_is_blocked(pinned_host):
+            raise ValueError("resolved OpenAI TTS target is not allowed")
+        pinned_hosts.append(pinned_host)
+    if not pinned_hosts:
+        raise ValueError("could not resolve OpenAI TTS base_url host")
+    return pinned_hosts
+
+
+def _tts_resolve_pinned_address(hostname: str) -> str:
+    """Return the first vetted literal address for direct helper callers."""
+    return _tts_resolve_pinned_addresses(hostname, None)[0]
+
+
 def _normalized_openai_tts_base_url(base_url: str) -> str:
     from urllib.parse import urlsplit, urlunsplit
 
@@ -16410,6 +16464,68 @@ def _buffer_tts_audio_response(resp, *, max_bytes: int | None = None) -> bytes:
             raise ValueError("upstream audio exceeded byte limit")
     return bytes(audio_data)
 
+
+class _NoRedirectTtsHandler(HTTPRedirectHandler):
+    """Refuse to follow redirects on the TTS call.
+
+    A redirect is never a legitimate response to POST /audio/speech and can
+    carry the Authorization bearer to a target that bypasses the base_url check.
+    """
+
+    def redirect_request(self, req, fp, code, msg, headers, newurl):
+        raise ValueError("OpenAI TTS upstream attempted a redirect")
+
+
+class _PinnedHTTPSConnection(http.client.HTTPSConnection):
+    """Connect to a pinned IP while keeping Host and TLS SNI on the hostname."""
+
+    def connect(self):
+        """Dial the first reachable vetted address, then start TLS.
+
+        The hostname is resolved and vetted here, at dial time, so the socket
+        can only ever be opened to an address that passed the block check.
+        """
+        # Local import: `errno` is not among the module imports this change adds.
+        import errno
+
+        sys.audit("http.client.connect", self, self.host, self.port)
+        last_error = None
+        for pinned_host in _tts_resolve_pinned_addresses(self.host, self.port):
+            try:
+                self.sock = _socket.create_connection(
+                    (pinned_host, self.port), self.timeout, self.source_address
+                )
+                break
+            except OSError as exc:
+                # Remember the failure and fall through to the next candidate.
+                last_error = exc
+        else:
+            # No candidate accepted the connection.
+            if last_error is not None:
+                raise last_error
+            raise OSError("could not connect to any pinned OpenAI TTS target")
+        try:
+            self.sock.setsockopt(_socket.IPPROTO_TCP, _socket.TCP_NODELAY, 1)
+        except OSError as exc:
+            # TCP_NODELAY is best-effort; only "option not supported" is ignored.
+            if exc.errno != errno.ENOPROTOOPT:
+                raise
+
+        if self._tunnel_host:
+            self._tunnel()
+
+        # TLS is negotiated for the hostname (SNI), not for the pinned IP.
+        server_hostname = self._tunnel_host or self.host
+        self.sock = self._context.wrap_socket(self.sock, server_hostname=server_hostname)
+
+
+class _PinnedHTTPSHandler(HTTPSHandler):
+    """HTTPS handler that opens connections via ``_PinnedHTTPSConnection``."""
+
+    def https_open(self, req):
+        return self.do_open(_PinnedHTTPSConnection, req, context=self._context)
+
+
 def _tts_open(req, *, timeout=30, opener_factory=None):
     """Thin network seam for the TTS upstream fetch so tests can intercept it.
 
@@ -16661,29 +16777,20 @@ def _handle_tts(handler, parsed):
                 "voice": oai_voice,
             }).encode("utf-8")
 
-            from urllib.request import Request, build_opener, HTTPRedirectHandler, urlopen as _urlopen
-
-            class _NoRedirectTtsHandler(HTTPRedirectHandler):
-                """Refuse to follow redirects on the TTS call. A redirect is never a
-                legitimate response to a POST /audio/speech, and following one would
-                (a) carry the Authorization bearer to the redirect target and
-                (b) let a public host bounce the request to a private/link-local
-                SSRF target after the base-url validation already passed."""
-
-                def redirect_request(self, req, fp, code, msg, headers, newurl):
-                    raise ValueError("OpenAI TTS upstream attempted a redirect")
-
+            from urllib.request import Request, build_opener, urlopen as _urlopen
             req = Request(url, data=req_body, headers={
                 "Authorization": f"Bearer {api_key}",
                 "Content-Type": "application/json",
                 "Accept": "audio/mpeg",
             })
 
-            # Use a no-redirect opener so an upstream redirect can't carry the bearer
-            # to (or SSRF-bounce into) a different/private target. _tts_open is a thin
-            # module seam so tests can still intercept the network call.
+            # Use a pinned HTTPS opener so the resolved address is the one that gets
+            # dialed. Keep the no-redirect handler in the same chain to block
+            # bearer leaks and SSRF bounce redirects after hostname validation.
+            # ProxyHandler({}) disables environment proxies so the request is
+            # dialed directly at the pinned address rather than through a proxy.
             try:
-                with _tts_open(req, timeout=30, opener_factory=lambda: build_opener(_NoRedirectTtsHandler())) as resp:
+                with _tts_open(req, timeout=30, opener_factory=lambda: build_opener(ProxyHandler({}), _NoRedirectTtsHandler(), _PinnedHTTPSHandler())) as resp:
                     audio_data = _buffer_tts_audio_response(resp)
             except ValueError:
                 logger.warning("OpenAI TTS rejected an invalid upstream response", exc_info=True)
diff --git a/tests/test_issue4982_openai_tts.py b/tests/test_issue4982_openai_tts.py
index 39c61ae0d..c910af0a7 100644
--- a/tests/test_issue4982_openai_tts.py
+++ b/tests/test_issue4982_openai_tts.py
@@ -1,6 +1,8 @@
 """OpenAI-compatible TTS endpoint and UI wiring coverage for #4982."""
 import io
 import json
+import socket
+import ssl
 from pathlib import Path
 
 import pytest
@@ -55,6 +57,39 @@ class _StreamOnceResponse:
         return b""
 
 
+def _http_response_bytes(status_code: int, body=b"", *, reason="OK", headers=None):
+    hdr = {"Content-Length": str(len(body)), "Content-Type": "audio/mpeg"}
+    if headers:
+        hdr.update(headers)
+    lines = [f"HTTP/1.1 {status_code} {reason}\r\n"]
+    for key, value in hdr.items():
+        lines.append(f"{key}: {value}\r\n")
+    lines.append("\r\n")
+    return "".join(lines).encode("utf-8") + body
+
+
+class _FakeSocketForHttps:
+    def __init__(self, response_body: bytes):
+        self.writes = []
+        self.response = io.BytesIO(response_body)
+        self.closed = False
+
+    def sendall(self, data):
+        self.writes.append(data)
+
+    def setsockopt(self, *_args, **_kwargs):
+        return None
+
+    def makefile(self, *_args, **_kwargs):
+        return self.response
+
+    def shutdown(self, *_args, **_kwargs):
+        return None
+
+    def close(self):
+        self.closed = True
+
+
 def _post(body_dict, **kw):
     body = json.dumps(body_dict).encode()
     return _FakeHandler(body, **kw)
@@ -151,6 +186,235 @@ def test_openai_tts_config_overrides(monkeypatch):
     assert captured["body"] == {"model": "tts-custom", "input": "Hello", "voice": "nova"}
 
 
+def test_tts_resolve_pinned_address_accepts_public_ip(monkeypatch):
+    def _fake_getaddrinfo(*_args, **_kwargs):
+        return [(0, 0, 0, "", ("1.1.1.1", 0))]
+    monkeypatch.setattr(socket, "getaddrinfo", _fake_getaddrinfo)
+    assert routes._tts_resolve_pinned_address("1.1.1.1") == "1.1.1.1"
+
+
+def test_tts_resolve_pinned_address_rejects_blocked_target(monkeypatch):
+    def _fake_getaddrinfo(*_args, **_kwargs):
+        return [(0, 0, 0, "", ("10.0.0.5", 0))]
+    monkeypatch.setattr(socket, "getaddrinfo", _fake_getaddrinfo)
+    with pytest.raises(ValueError, match="not allowed"):
+        routes._tts_resolve_pinned_address("public.example.com")
+
+
+def test_tts_resolve_pinned_address_rejects_mixed_addresses(monkeypatch):
+    def _fake_getaddrinfo(*_args, **_kwargs):
+        # The first entry must be genuinely public: 203.0.113.0/24 (TEST-NET-3)
+        # is itself reported as private by ``ipaddress``.
+        return [
+            (0, 0, 0, "", ("1.1.1.1", 0)),
+            (0, 0, 0, "", ("127.0.0.1", 0)),
+        ]
+    monkeypatch.setattr(socket, "getaddrinfo", _fake_getaddrinfo)
+    with pytest.raises(ValueError, match="not allowed"):
+        routes._tts_resolve_pinned_address("public.example.com")
+
+
+def test_openai_tts_does_not_connect_to_rebound_private_address(monkeypatch):
+    host = "rebind-openai.example.com"
+    counts = {"getaddrinfo": 0}
+    created = []
+
+    def _fake_getaddrinfo(*_args, **_kwargs):
+        counts["getaddrinfo"] += 1
+        if counts["getaddrinfo"] == 1:
+            return [(0, 0, 0, "", ("1.1.1.1", 443))]
+        return [(0, 0, 0, "", ("169.254.169.254", 443))]
+
+    def _fake_create_connection(address, *args, **_kwargs):
+        dial_host, dial_port = address
+        try:
+            socket.inet_aton(dial_host)
+            resolved = dial_host
+        except OSError:
+            resolved = socket.getaddrinfo(dial_host, dial_port)[0][4][0]
+        created.append((resolved, dial_port))
+        raise AssertionError(f"connect should not run with this test; got {(resolved, dial_port)}")
+
+    def _fake_wrap_socket(_context, sock, *args, **kwargs):
+        return sock
+
+    import api.config as config
+
+    monkeypatch.setattr(socket, "getaddrinfo", _fake_getaddrinfo)
+    monkeypatch.setattr(socket, "create_connection", _fake_create_connection)
+    monkeypatch.setattr(ssl.SSLContext, "wrap_socket", _fake_wrap_socket)
+    monkeypatch.setenv("OPENAI_API_KEY", "sk-openai")
+    monkeypatch.setattr(config, "get_config", lambda: {
+        "tts": {"openai": {"base_url": f"https://{host}/v1"}}
+    })
+
+    h = _post({"text": "Hello", "engine": "openai"}, client="10.82.0.8")
+    routes._handle_tts(h, None)
+
+    assert counts["getaddrinfo"] == 2
+    assert created == []
+    assert h.status == 502
+    assert "OpenAI TTS generation failed" in (h.payload() or {}).get("error", "")
+
+
+def test_openai_tts_pinned_connection_preserves_host_and_sni(monkeypatch):
+    host = "static-openai.example.com"
+    response_bytes = _http_response_bytes(200, b"audio-openai")
+    fake_socket = _FakeSocketForHttps(response_bytes)
+    observed = {}
+    created = []
+
+    def _fake_getaddrinfo(*_args, **_kwargs):
+        return [(0, 0, 0, "", ("1.1.1.1", 443))]
+
+    def _fake_create_connection(address, *_args, **_kwargs):
+        created.append(address)
+        return fake_socket
+
+    def _fake_wrap_socket(_context, sock, *args, server_hostname=None, **_kwargs):
+        observed["server_hostname"] = server_hostname
+        return sock
+
+    import api.config as config
+
+    monkeypatch.setattr(socket, "getaddrinfo", _fake_getaddrinfo)
+    monkeypatch.setattr(socket, "create_connection", _fake_create_connection)
+    monkeypatch.setattr(ssl.SSLContext, "wrap_socket", _fake_wrap_socket)
+    monkeypatch.setenv("OPENAI_API_KEY", "sk-openai")
+    monkeypatch.setattr(config, "get_config", lambda: {
+        "tts": {"openai": {"base_url": f"https://{host}/v1"}}
+    })
+    h = _post({"text": "Hello", "engine": "openai"}, client="10.82.0.9")
+    routes._handle_tts(h, None)
+
+    assert h.status == 200
+    assert h.sent_headers["Content-Type"] == "audio/mpeg"
+    assert h.wfile.getvalue() == b"audio-openai"
+    assert created == [("1.1.1.1", 443)]
+    assert observed["server_hostname"] == host
+    sent = b"".join(fake_socket.writes).decode("utf-8", "replace")
+    assert f"Host: {host}" in sent
+
+
+def test_openai_tts_pinned_connection_tries_later_vetted_candidate(monkeypatch):
+    host = "multi-openai.example.com"
+    response_bytes = _http_response_bytes(200, b"audio-openai")
+    fake_socket = _FakeSocketForHttps(response_bytes)
+    observed = {"getaddrinfo": 0, "server_hostname": None}
+    created = []
+
+    def _fake_getaddrinfo(target_host, target_port=None, *_args, **_kwargs):
+        observed["getaddrinfo"] += 1
+        assert target_host == host
+        port = target_port or 443
+        return [
+            (socket.AF_INET6, socket.SOCK_STREAM, 6, "", ("2606:4700:4700::1111", port, 0, 0)),
+            (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("1.1.1.1", port)),
+        ]
+
+    def _fake_create_connection(address, *_args, **_kwargs):
+        created.append(address)
+        if address[0] == "2606:4700:4700::1111":
+            raise OSError("ipv6 unavailable")
+        if address[0] == "1.1.1.1":
+            return fake_socket
+        raise AssertionError(f"unexpected connect target {address}")
+
+    def _fake_wrap_socket(_context, sock, *args, server_hostname=None, **_kwargs):
+        observed["server_hostname"] = server_hostname
+        return sock
+
+    import api.config as config
+
+    monkeypatch.setattr(socket, "getaddrinfo", _fake_getaddrinfo)
+    monkeypatch.setattr(socket, "create_connection", _fake_create_connection)
+    monkeypatch.setattr(ssl.SSLContext, "wrap_socket", _fake_wrap_socket)
+    monkeypatch.setenv("OPENAI_API_KEY", "sk-openai")
+    monkeypatch.setattr(config, "get_config", lambda: {
+        "tts": {"openai": {"base_url": f"https://{host}/v1"}}
+    })
+    h = _post({"text": "Hello", "engine": "openai"}, client="10.82.0.12")
+    routes._handle_tts(h, None)
+
+    assert h.status == 200
+    assert h.wfile.getvalue() == b"audio-openai"
+    assert created == [("2606:4700:4700::1111", 443), ("1.1.1.1", 443)]
+    assert observed["getaddrinfo"] == 2
+    assert observed["server_hostname"] == host
+
+
+def test_openai_tts_rejects_redirect_with_pinned_opener(monkeypatch):
+    host = "redirect-openai.example.com"
+    response_bytes = _http_response_bytes(302, headers={"Location": "http://169.254.169.254/v1/audio/speech"}, reason="Found")
+    fake_socket = _FakeSocketForHttps(response_bytes)
+    created = []
+
+    def _fake_getaddrinfo(*_args, **_kwargs):
+        return [(0, 0, 0, "", ("1.1.1.1", 443))]
+
+    def _fake_create_connection(address, *_args, **_kwargs):
+        created.append(address)
+        return fake_socket
+
+    def _fake_wrap_socket(_context, sock, *args, **kwargs):
+        return sock
+
+    import api.config as config
+
+    monkeypatch.setattr(socket, "getaddrinfo", _fake_getaddrinfo)
+    monkeypatch.setattr(socket, "create_connection", _fake_create_connection)
+    monkeypatch.setattr(ssl.SSLContext, "wrap_socket", _fake_wrap_socket)
+    monkeypatch.setenv("OPENAI_API_KEY", "sk-openai")
+    monkeypatch.setattr(config, "get_config", lambda: {
+        "tts": {"openai": {"base_url": f"https://{host}/v1"}}
+    })
+    h = _post({"text": "Hello", "engine": "openai"}, client="10.82.0.10")
+    routes._handle_tts(h, None)
+
+    assert h.status in (500, 502)
+    assert created == [("1.1.1.1", 443)]
+    assert "OpenAI TTS generation failed" in (h.payload() or {}).get("error", "")
+
+
+def test_openai_tts_ignores_https_proxy_and_dials_pinned_target(monkeypatch):
+    host = "proxy-safe-openai.example.com"
+    response_bytes = _http_response_bytes(200, b"audio-openai")
+    fake_socket = _FakeSocketForHttps(response_bytes)
+    created = []
+
+    def _fake_getaddrinfo(target_host, *_args, **_kwargs):
+        if target_host == host:
+            return [(0, 0, 0, "", ("1.1.1.1", 443))]
+        if target_host == "127.0.0.1":
+            return [(0, 0, 0, "", ("127.0.0.1", 8888))]
+        raise AssertionError(f"unexpected DNS lookup for {target_host}")
+
+    def _fake_create_connection(address, *_args, **_kwargs):
+        created.append(address)
+        return fake_socket
+
+    def _fake_wrap_socket(_context, sock, *args, **kwargs):
+        return sock
+
+    import api.config as config
+
+    monkeypatch.setattr(socket, "getaddrinfo", _fake_getaddrinfo)
+    monkeypatch.setattr(socket, "create_connection", _fake_create_connection)
+    monkeypatch.setattr(ssl.SSLContext, "wrap_socket", _fake_wrap_socket)
+    monkeypatch.setenv("OPENAI_API_KEY", "sk-openai")
+    monkeypatch.setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
+    monkeypatch.setenv("https_proxy", "http://127.0.0.1:8888")
+    monkeypatch.setattr(config, "get_config", lambda: {
+        "tts": {"openai": {"base_url": f"https://{host}/v1"}}
+    })
+    h = _post({"text": "Hello", "engine": "openai"}, client="10.82.0.11")
+    routes._handle_tts(h, None)
+
+    assert h.status == 200
+    assert h.wfile.getvalue() == b"audio-openai"
+    assert created == [("1.1.1.1", 443)]
+
+
 @pytest.mark.parametrize("base_url", [
     "http://169.254.169.254/v1",
     "https://user:pass@api.example.com/v1",