diff --git a/VERSION.txt b/VERSION.txt index df83a51c6cb9a..27a0f1d276599 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.94.1 +1.96.5 diff --git a/k8s-operator/api-proxy/proxy.go b/k8s-operator/api-proxy/proxy.go index f5f1da80f1a05..acc7b62341b83 100644 --- a/k8s-operator/api-proxy/proxy.go +++ b/k8s-operator/api-proxy/proxy.go @@ -1,4 +1,4 @@ -// Copyright (c) Tailscale Inc & AUTHORS +// Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause //go:build !plan9 @@ -21,6 +21,7 @@ import ( "strings" "time" + "github.com/pires/go-proxyproto" "go.uber.org/zap" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apiserver/pkg/endpoints/request" @@ -28,7 +29,6 @@ import ( "k8s.io/client-go/transport" "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" - "tailscale.com/envknob" ksr "tailscale.com/k8s-operator/sessionrecording" "tailscale.com/kube/kubetypes" "tailscale.com/net/netx" @@ -43,13 +43,7 @@ import ( var ( // counterNumRequestsproxies counts the number of API server requests proxied via this proxy. counterNumRequestsProxied = clientmetric.NewCounter("k8s_auth_proxy_requests_proxied") - // NOTE: adding this metric so we can keep track of users during deprecation - counterExperimentalEventsVarUsed = clientmetric.NewCounter("ts_experimental_kube_api_events_var_used") - whoIsKey = ctxkey.New("", (*apitype.WhoIsResponse)(nil)) -) - -const ( - eventsEnabledVar = "TS_EXPERIMENTAL_KUBE_API_EVENTS" + whoIsKey = ctxkey.New("", (*apitype.WhoIsResponse)(nil)) ) // NewAPIServerProxy creates a new APIServerProxy that's ready to start once Run @@ -103,7 +97,6 @@ func NewAPIServerProxy(zlog *zap.SugaredLogger, restConfig *rest.Config, ts *tsn upstreamURL: u, ts: ts, sendEventFunc: sessionrecording.SendEvent, - eventsEnabled: envknob.Bool(eventsEnabledVar), } ap.rp = &httputil.ReverseProxy{ Rewrite: func(pr *httputil.ProxyRequest) { @@ -134,11 +127,6 @@ func (ap *APIServerProxy) Run(ctx context.Context) error { TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), } - if ap.eventsEnabled { - counterExperimentalEventsVarUsed.Add(1) - ap.log.Warnf("DEPRECATED: %q environment variable is deprecated, and will be removed in v1.96. See documentation for more detail.", eventsEnabledVar) - } - mode := "noauth" if ap.authMode { mode = "auth" @@ -163,10 +151,18 @@ func (ap *APIServerProxy) Run(ctx context.Context) error { } } else { var err error - proxyLn, err = net.Listen("tcp", "localhost:80") + baseLn, err := net.Listen("tcp", "localhost:80") if err != nil { return fmt.Errorf("could not listen on :80: %w", err) } + proxyLn = &proxyproto.Listener{ + Listener: baseLn, + ReadHeaderTimeout: 10 * time.Second, + ConnPolicy: proxyproto.ConnPolicyFunc(func(opts proxyproto.ConnPolicyOptions) (proxyproto.Policy, + error) { + return proxyproto.REQUIRE, nil + }), + } serve = ap.hs.Serve } @@ -205,10 +201,6 @@ type APIServerProxy struct { upstreamURL *url.URL sendEventFunc func(ap netip.AddrPort, event io.Reader, dial netx.DialFunc) error - - // Flag used to enable sending API requests as events to tsrecorder. - // Deprecated: events are now set via ACLs (see https://tailscale.com/kb/1246/tailscale-ssh-session-recording#turn-on-session-recording-in-your-tailnet-policy-file) - eventsEnabled bool } // serveDefault is the default handler for Kubernetes API server requests. @@ -237,8 +229,7 @@ func (ap *APIServerProxy) serveDefault(w http.ResponseWriter, r *http.Request) { return } - // NOTE: (ChaosInTheCRD) ap.eventsEnabled deprecated, remove in v1.96 - if c.enableEvents || ap.eventsEnabled { + if c.enableEvents { if err = ap.recordRequestAsEvent(r, who, c.recorderAddresses, c.failOpen); err != nil { msg := fmt.Sprintf("error recording Kubernetes API request: %v", err) ap.log.Errorf(msg) @@ -308,8 +299,7 @@ func (ap *APIServerProxy) sessionForProto(w http.ResponseWriter, r *http.Request return } - // NOTE: (ChaosInTheCRD) ap.eventsEnabled deprecated, remove in v1.96 - if c.enableEvents || ap.eventsEnabled { + if c.enableEvents { if err = ap.recordRequestAsEvent(r, who, c.recorderAddresses, c.failOpen); err != nil { msg := fmt.Sprintf("error recording Kubernetes API request: %v", err) ap.log.Errorf(msg) diff --git a/net/netmon/interfaces.go b/net/netmon/interfaces.go index 4cf93973c6473..c7a2cb213e893 100644 --- a/net/netmon/interfaces.go +++ b/net/netmon/interfaces.go @@ -1,4 +1,4 @@ -// Copyright (c) Tailscale Inc & AUTHORS +// Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause package netmon diff --git a/tsnet/example/tsnet-services/tsnet-services.go b/tsnet/example/tsnet-services/tsnet-services.go index 6eb1a76ab5f5c..d72fd68fd412a 100644 --- a/tsnet/example/tsnet-services/tsnet-services.go +++ b/tsnet/example/tsnet-services/tsnet-services.go @@ -1,4 +1,4 @@ -// Copyright (c) Tailscale Inc & AUTHORS +// Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause // The tsnet-services example demonstrates how to use tsnet with Services. diff --git a/tsnet/example_tsnet_listen_service_multiple_ports_test.go b/tsnet/example_tsnet_listen_service_multiple_ports_test.go index 04781c2b20d16..5fe86a9ecf9fe 100644 --- a/tsnet/example_tsnet_listen_service_multiple_ports_test.go +++ b/tsnet/example_tsnet_listen_service_multiple_ports_test.go @@ -1,4 +1,4 @@ -// Copyright (c) Tailscale Inc & AUTHORS +// Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause package tsnet_test @@ -19,21 +19,19 @@ import ( // Service on multiple ports. In this example, we run an HTTPS server on 443 and // an HTTP server handling pprof requests to the same runtime on 6060. func ExampleServer_ListenService_multiplePorts() { - s := &tsnet.Server{ - Hostname: "tsnet-services-demo", + srv := &tsnet.Server{ + Hostname: "shu", } - defer s.Close() - ln, err := s.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ + ln, err := srv.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ HTTPS: true, Port: 443, }) if err != nil { log.Fatal(err) } - defer ln.Close() - pprofLn, err := s.ListenService("svc:my-service", tsnet.ServiceModeTCP{ + pprofLn, err := srv.ListenService("svc:my-service", tsnet.ServiceModeTCP{ Port: 6060, }) if err != nil { diff --git a/tsnet/example_tsnet_test.go b/tsnet/example_tsnet_test.go index 2a3236b3b6501..2af31a76f787f 100644 --- a/tsnet/example_tsnet_test.go +++ b/tsnet/example_tsnet_test.go @@ -1,4 +1,4 @@ -// Copyright (c) Tailscale Inc & AUTHORS +// Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause package tsnet_test @@ -205,19 +205,17 @@ func ExampleServer_ListenFunnel_funnelOnly() { // ExampleServer_ListenService demonstrates how to advertise an HTTPS Service. func ExampleServer_ListenService() { - s := &tsnet.Server{ - Hostname: "tsnet-services-demo", + srv := &tsnet.Server{ + Hostname: "atum", } - defer s.Close() - ln, err := s.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ + ln, err := srv.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ HTTPS: true, Port: 443, }) if err != nil { log.Fatal(err) } - defer ln.Close() log.Printf("Listening on https://%v\n", ln.FQDN) log.Fatal(http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -238,19 +236,17 @@ func ExampleServer_ListenService_reverseProxy() { Host: targetAddress, }) - s := &tsnet.Server{ - Hostname: "tsnet-services-demo", + srv := &tsnet.Server{ + Hostname: "tefnut", } - defer s.Close() - ln, err := s.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ + ln, err := srv.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ HTTPS: true, Port: 443, }) if err != nil { log.Fatal(err) } - defer ln.Close() log.Printf("Listening on https://%v\n", ln.FQDN) log.Fatal(http.Serve(ln, reverseProxy)) diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 6c840c335535e..776854e227926 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -1,4 +1,4 @@ -// Copyright (c) Tailscale Inc & AUTHORS +// Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause // Package tsnet provides Tailscale as a library. @@ -26,6 +26,7 @@ import ( "sync" "time" + "github.com/tailscale/wireguard-go/tun" "tailscale.com/client/local" "tailscale.com/control/controlclient" "tailscale.com/envknob" @@ -167,6 +168,11 @@ type Server struct { // that the control server will allow the node to adopt that tag. AdvertiseTags []string + // Tun, if non-nil, specifies a custom tun.Device to use for packet I/O. + // + // This field must be set before calling Start. + Tun tun.Device + initOnce sync.Once initErr error lb *ipnlocal.LocalBackend @@ -190,9 +196,10 @@ type Server struct { mu sync.Mutex listeners map[listenKey]*listener + nextEphemeralPort uint16 // next port to try in ephemeral range; 0 means use ephemeralPortFirst fallbackTCPHandlers set.HandleSet[FallbackTCPHandler] dialer *tsdial.Dialer - closed bool + closeOnce sync.Once } // FallbackTCPHandler describes the callback which @@ -433,11 +440,29 @@ func (s *Server) Up(ctx context.Context) (*ipnstate.Status, error) { // // It must not be called before or concurrently with Start. func (s *Server) Close() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.closed { + didClose := false + s.closeOnce.Do(func() { + didClose = true + s.close() + }) + if !didClose { return fmt.Errorf("tsnet: %w", net.ErrClosed) } + return nil +} + +func (s *Server) close() { + // Close listeners under s.mu, then release before the heavy shutdown + // operations. We must not hold s.mu during netstack.Close, lb.Shutdown, + // etc. because callbacks from gVisor (e.g. getTCPHandlerForFlow) + // acquire s.mu, and waiting for those goroutines while holding the lock + // would deadlock. + s.mu.Lock() + for _, ln := range s.listeners { + ln.closeLocked() + } + s.mu.Unlock() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() var wg sync.WaitGroup @@ -460,13 +485,12 @@ func (s *Server) Close() error { } }() - if s.netstack != nil { - s.netstack.Close() - s.netstack = nil - } if s.shutdownCancel != nil { s.shutdownCancel() } + if s.netstack != nil { + s.netstack.Close() + } if s.lb != nil { s.lb.Shutdown() } @@ -483,13 +507,8 @@ func (s *Server) Close() error { s.loopbackListener.Close() } - for _, ln := range s.listeners { - ln.closeLocked() - } wg.Wait() s.sys.Bus.Get().Close() - s.closed = true - return nil } func (s *Server) doInit() { @@ -659,6 +678,7 @@ func (s *Server) start() (reterr error) { s.dialer = &tsdial.Dialer{Logf: tsLogf} // mutated below (before used) s.dialer.SetBus(sys.Bus.Get()) eng, err := wgengine.NewUserspaceEngine(tsLogf, wgengine.Config{ + Tun: s.Tun, EventBus: sys.Bus.Get(), ListenPort: s.Port, NetMon: s.netMon, @@ -682,8 +702,16 @@ func (s *Server) start() (reterr error) { } sys.Tun.Get().Start() sys.Set(ns) - ns.ProcessLocalIPs = true - ns.ProcessSubnets = true + if s.Tun == nil { + // Only process packets in netstack when using the default fake TUN. + // When a TUN is provided, let packets flow through it instead. + ns.ProcessLocalIPs = true + ns.ProcessSubnets = true + } else { + // When using a TUN, check gVisor for registered endpoints to handle + // packets for tsnet listeners and outbound connection replies. + ns.CheckLocalTransportEndpoints = true + } ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow ns.GetUDPHandlerForFlow = s.getUDPHandlerForFlow s.netstack = ns @@ -1075,7 +1103,42 @@ func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) { if err := s.Start(); err != nil { return nil, err } - return s.netstack.ListenPacket(network, ap.String()) + + // Create the gVisor PacketConn first so it can handle port 0 allocation. + pc, err := s.netstack.ListenPacket(network, ap.String()) + if err != nil { + return nil, err + } + + // If port 0 was requested, use the port gVisor assigned. + if ap.Port() == 0 { + if p := portFromAddr(pc.LocalAddr()); p != 0 { + ap = netip.AddrPortFrom(ap.Addr(), p) + addr = ap.String() + } + } + + ln, err := s.registerListener(network, addr, ap, listenOnTailnet, nil) + if err != nil { + pc.Close() + return nil, err + } + + return &udpPacketConn{ + PacketConn: pc, + ln: ln, + }, nil +} + +// udpPacketConn wraps a net.PacketConn to unregister from s.listeners on Close. +type udpPacketConn struct { + net.PacketConn + ln *listener +} + +func (c *udpPacketConn) Close() error { + c.ln.Close() + return c.PacketConn.Close() } // ListenTLS announces only on the Tailscale network. @@ -1410,6 +1473,8 @@ var ErrUntaggedServiceHost = errors.New("service hosts must be tagged nodes") // To advertise a Service with multiple ports, run ListenService multiple times. // For more information about Services, see // https://tailscale.com/kb/1552/tailscale-services +// +// This function will start the server if it is not already started. func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener, error) { if err := tailcfg.ServiceName(name).Validate(); err != nil { return nil, err @@ -1568,6 +1633,11 @@ func resolveListenAddr(network, addr string) (netip.AddrPort, error) { if err != nil { return zero, fmt.Errorf("invalid Listen addr %q; host part must be empty or IP literal", host) } + // Normalize unspecified addresses (0.0.0.0, ::) to the zero value, + // equivalent to an empty host, so they match the node's own IPs. + if bindHostOrZero.IsUnspecified() { + return netip.AddrPortFrom(netip.Addr{}, uint16(port)), nil + } if strings.HasSuffix(network, "4") && !bindHostOrZero.Is4() { return zero, fmt.Errorf("invalid non-IPv4 addr %v for network %q", host, network) } @@ -1577,6 +1647,17 @@ func resolveListenAddr(network, addr string) (netip.AddrPort, error) { return netip.AddrPortFrom(bindHostOrZero, uint16(port)), nil } +// ephemeral port range for non-TUN listeners requesting port 0. This range is +// chosen to reduce the probability of collision with host listeners, avoiding +// both the typical ephemeral range, and privilege listener ranges. Collisions +// may still occur and could for example shadow host sockets in a netstack+TUN +// situation, the range here is a UX improvement, not a guarantee that +// application authors will never have to consider these cases. +const ( + ephemeralPortFirst = 10002 + ephemeralPortLast = 19999 +) + func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, error) { switch network { case "", "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": @@ -1590,6 +1671,76 @@ func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, erro if err := s.Start(); err != nil { return nil, err } + + isTCP := network == "" || network == "tcp" || network == "tcp4" || network == "tcp6" + + // When using a TUN with TCP, create a gVisor TCP listener. + // gVisor handles port 0 allocation natively. + var gonetLn net.Listener + if s.Tun != nil && isTCP { + gonetLn, err = s.listenTCP(network, host) + if err != nil { + return nil, err + } + // If port 0 was requested, update host to the port gVisor assigned + // so that the listenKey uses the real port. + if host.Port() == 0 { + if p := portFromAddr(gonetLn.Addr()); p != 0 { + host = netip.AddrPortFrom(host.Addr(), p) + addr = listenAddr(host) + } + } + } + + ln, err := s.registerListener(network, addr, host, lnOn, gonetLn) + if err != nil { + if gonetLn != nil { + gonetLn.Close() + } + return nil, err + } + return ln, nil +} + +// listenTCP creates a gVisor TCP listener for TUN mode. +func (s *Server) listenTCP(network string, host netip.AddrPort) (net.Listener, error) { + var nsNetwork string + nsAddr := host + switch { + case network == "tcp4" || network == "tcp6": + nsNetwork = network + case host.Addr().Is4(): + nsNetwork = "tcp4" + case host.Addr().Is6(): + nsNetwork = "tcp6" + default: + // Wildcard address: use tcp6 for dual-stack (accepts both v4 and v6). + nsNetwork = "tcp6" + nsAddr = netip.AddrPortFrom(netip.IPv6Unspecified(), host.Port()) + } + ln, err := s.netstack.ListenTCP(nsNetwork, nsAddr.String()) + if err != nil { + return nil, fmt.Errorf("tsnet: %w", err) + } + return ln, nil +} + +// registerListener allocates a port (if 0) and registers the listener in +// s.listeners under s.mu. +func (s *Server) registerListener(network, addr string, host netip.AddrPort, lnOn listenOn, gonetLn net.Listener) (*listener, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // Allocate an ephemeral port for non-TUN listeners requesting port 0. + if host.Port() == 0 && gonetLn == nil { + p, ok := s.allocEphemeralLocked(network, host.Addr(), lnOn) + if !ok { + return nil, errors.New("tsnet: no available port in ephemeral range") + } + host = netip.AddrPortFrom(host.Addr(), p) + addr = listenAddr(host) + } + var keys []listenKey switch lnOn { case listenOnTailnet: @@ -1601,31 +1752,93 @@ func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, erro keys = append(keys, listenKey{network, host.Addr(), host.Port(), true}) } - ln := &listener{ - s: s, - keys: keys, - addr: addr, - - closedc: make(chan struct{}), - conn: make(chan net.Conn), - } - s.mu.Lock() for _, key := range keys { if _, ok := s.listeners[key]; ok { - s.mu.Unlock() return nil, fmt.Errorf("tsnet: listener already open for %s, %s", network, addr) } } + + ln := &listener{ + s: s, + keys: keys, + addr: addr, + closedc: make(chan struct{}), + conn: make(chan net.Conn), + gonetLn: gonetLn, + } if s.listeners == nil { s.listeners = make(map[listenKey]*listener) } for _, key := range keys { s.listeners[key] = ln } - s.mu.Unlock() return ln, nil } +// allocEphemeralLocked finds an unused port in [ephemeralPortFirst, +// ephemeralPortLast] that does not collide with any existing listener for the +// given network, host, and listenOn. s.mu must be held. +func (s *Server) allocEphemeralLocked(network string, host netip.Addr, lnOn listenOn) (uint16, bool) { + if s.nextEphemeralPort < ephemeralPortFirst || s.nextEphemeralPort > ephemeralPortLast { + s.nextEphemeralPort = ephemeralPortFirst + } + start := s.nextEphemeralPort + for { + p := s.nextEphemeralPort + s.nextEphemeralPort++ + if s.nextEphemeralPort > ephemeralPortLast { + s.nextEphemeralPort = ephemeralPortFirst + } + if !s.portInUseLocked(network, host, p, lnOn) { + return p, true + } + if s.nextEphemeralPort == start { + return 0, false + } + } +} + +// portInUseLocked reports whether any listenKey for the given network, host, +// port, and listenOn already exists in s.listeners. +func (s *Server) portInUseLocked(network string, host netip.Addr, port uint16, lnOn listenOn) bool { + switch lnOn { + case listenOnTailnet: + _, ok := s.listeners[listenKey{network, host, port, false}] + return ok + case listenOnFunnel: + _, ok := s.listeners[listenKey{network, host, port, true}] + return ok + case listenOnBoth: + _, ok1 := s.listeners[listenKey{network, host, port, false}] + _, ok2 := s.listeners[listenKey{network, host, port, true}] + return ok1 || ok2 + } + return false +} + +// listenAddr formats host as a listen address string. +// If host has no IP, it returns ":port". +func listenAddr(host netip.AddrPort) string { + if !host.Addr().IsValid() { + return ":" + strconv.Itoa(int(host.Port())) + } + return host.String() +} + +// portFromAddr extracts the port from a net.Addr, or returns 0. +func portFromAddr(a net.Addr) uint16 { + switch v := a.(type) { + case *net.TCPAddr: + return uint16(v.Port) + case *net.UDPAddr: + return uint16(v.Port) + } + if ap, err := netip.ParseAddrPort(a.String()); err == nil { + return ap.Port() + } + return 0 +} + // GetRootPath returns the root path of the tsnet server. // This is where the state file and other data is stored. func (s *Server) GetRootPath() string { @@ -1682,9 +1895,17 @@ type listener struct { conn chan net.Conn // unbuffered, never closed closedc chan struct{} // closed on [listener.Close] closed bool // guarded by s.mu + + // gonetLn, if set, is the gonet.Listener that handles new connections. + // gonetLn is set by [listen] when a TUN is in use and terminates the listener. + // gonetLn is nil when TUN is nil. + gonetLn net.Listener } func (ln *listener) Accept() (net.Conn, error) { + if ln.gonetLn != nil { + return ln.gonetLn.Accept() + } select { case c := <-ln.conn: return c, nil @@ -1694,6 +1915,9 @@ func (ln *listener) Accept() (net.Conn, error) { } func (ln *listener) Addr() net.Addr { + if ln.gonetLn != nil { + return ln.gonetLn.Addr() + } return addr{ network: ln.keys[0].network, addr: ln.addr, @@ -1719,6 +1943,9 @@ func (ln *listener) closeLocked() error { } close(ln.closedc) ln.closed = true + if ln.gonetLn != nil { + ln.gonetLn.Close() + } return nil } diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index f44bacab08431..1cf4bf48fe5bd 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -1,4 +1,4 @@ -// Copyright (c) Tailscale Inc & AUTHORS +// Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause package tsnet @@ -39,6 +39,7 @@ import ( "github.com/google/go-cmp/cmp" dto "github.com/prometheus/client_model/go" "github.com/prometheus/common/expfmt" + "github.com/tailscale/wireguard-go/tun" "golang.org/x/net/proxy" "tailscale.com/client/local" @@ -48,11 +49,13 @@ import ( "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" "tailscale.com/net/netns" + "tailscale.com/net/packet" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstest/deptest" "tailscale.com/tstest/integration" "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/views" @@ -109,6 +112,86 @@ func TestListenerPort(t *testing.T) { } } +func TestResolveListenAddrUnspecified(t *testing.T) { + tests := []struct { + name string + network string + addr string + wantIP netip.Addr + }{ + {"empty_host", "tcp", ":80", netip.Addr{}}, + {"ipv4_unspecified", "tcp", "0.0.0.0:80", netip.Addr{}}, + {"ipv6_unspecified", "tcp", "[::]:80", netip.Addr{}}, + {"specific_ipv4", "tcp", "100.64.0.1:80", netip.MustParseAddr("100.64.0.1")}, + {"specific_ipv6", "tcp6", "[fd7a:115c:a1e0::1]:80", netip.MustParseAddr("fd7a:115c:a1e0::1")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := resolveListenAddr(tt.network, tt.addr) + if err != nil { + t.Fatal(err) + } + if got.Addr() != tt.wantIP { + t.Errorf("Addr() = %v, want %v", got.Addr(), tt.wantIP) + } + }) + } +} + +func TestAllocEphemeral(t *testing.T) { + s := &Server{listeners: make(map[listenKey]*listener)} + + // Sequential allocations should return unique ports in range. + var ports []uint16 + for range 5 { + s.mu.Lock() + p, ok := s.allocEphemeralLocked("tcp", netip.Addr{}, listenOnTailnet) + s.mu.Unlock() + if !ok { + t.Fatal("allocEphemeralLocked failed unexpectedly") + } + if p < ephemeralPortFirst || p > ephemeralPortLast { + t.Errorf("port %d outside [%d, %d]", p, ephemeralPortFirst, ephemeralPortLast) + } + for _, prev := range ports { + if p == prev { + t.Errorf("duplicate port %d", p) + } + } + ports = append(ports, p) + // Occupy the port so the next call skips it. + s.listeners[listenKey{"tcp", netip.Addr{}, p, false}] = &listener{} + } + + // Verify skip over occupied port. + s.mu.Lock() + next := s.nextEphemeralPort + if next < ephemeralPortFirst || next > ephemeralPortLast { + next = ephemeralPortFirst + } + s.listeners[listenKey{"tcp", netip.Addr{}, next, false}] = &listener{} + p, ok := s.allocEphemeralLocked("tcp", netip.Addr{}, listenOnTailnet) + s.mu.Unlock() + if !ok { + t.Fatal("allocEphemeralLocked failed after skip") + } + if p == next { + t.Errorf("should have skipped occupied port %d", next) + } + + // Wrap-around. + s.mu.Lock() + s.nextEphemeralPort = ephemeralPortLast + p, ok = s.allocEphemeralLocked("tcp", netip.Addr{}, listenOnTailnet) + s.mu.Unlock() + if !ok { + t.Fatal("allocEphemeralLocked failed at wrap") + } + if p < ephemeralPortFirst || p > ephemeralPortLast { + t.Errorf("port %d outside range after wrap", p) + } +} + var verboseDERP = flag.Bool("verbose-derp", false, "if set, print DERP and STUN logs") var verboseNodes = flag.Bool("verbose-nodes", false, "if set, print tsnet.Server logs") @@ -362,7 +445,7 @@ func TestConn(t *testing.T) { for { c, err := ln.Accept() if err != nil { - if ctx.Err() != nil { + if ctx.Err() != nil || errors.Is(err, net.ErrClosed) { return } t.Errorf("s1.Accept: %v", err) @@ -1138,83 +1221,89 @@ func TestListenService(t *testing.T) { // This ends up also testing the Service forwarding logic in // LocalBackend, but that's useful too. t.Run(tt.name, func(t *testing.T) { - ctx := t.Context() - - controlURL, control := startControl(t) - serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host") - serviceClient, _, _ := startServer(t, ctx, controlURL, "service-client") - - const serviceName = tailcfg.ServiceName("svc:foo") - const serviceVIP = "100.11.22.33" - - // == Set up necessary state in our mock == - - // The Service host must have the 'service-host' capability, which - // is a mapping from the Service name to the Service VIP. - var serviceHostCaps map[tailcfg.ServiceName]views.Slice[netip.Addr] - mak.Set(&serviceHostCaps, serviceName, views.SliceOf([]netip.Addr{netip.MustParseAddr(serviceVIP)})) - j := must.Get(json.Marshal(serviceHostCaps)) - cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap() - mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(j)}) - control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm) - - // The Service host must be allowed to advertise the Service VIP. - control.SetSubnetRoutes(serviceHost.lb.NodeKey(), []netip.Prefix{ - netip.MustParsePrefix(serviceVIP + `/32`), - }) - - // The Service host must be a tagged node (any tag will do). - serviceHostNode := control.Node(serviceHost.lb.NodeKey()) - serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag") - control.UpdateNode(serviceHostNode) - - // The service client must accept routes advertised by other nodes - // (RouteAll is equivalent to --accept-routes). - must.Get(serviceClient.localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ - RouteAllSet: true, - Prefs: ipn.Prefs{ - RouteAll: true, - }, - })) - - // Set up DNS for our Service. - control.AddDNSRecords(tailcfg.DNSRecord{ - Name: serviceName.WithoutPrefix() + "." + control.MagicDNSDomain, - Value: serviceVIP, - }) + // We run each test with and without a TUN device ([Server.Tun]). + // Note that this TUN device is distinct from TUN mode for Services. + doTest := func(t *testing.T, withTUNDevice bool) { + ctx := t.Context() + + lt := setupTwoClientTest(t, withTUNDevice) + serviceHost := lt.s2 + serviceClient := lt.s1 + control := lt.control + + const serviceName = tailcfg.ServiceName("svc:foo") + const serviceVIP = "100.11.22.33" + + // == Set up necessary state in our mock == + + // The Service host must have the 'service-host' capability, which + // is a mapping from the Service name to the Service VIP. + var serviceHostCaps map[tailcfg.ServiceName]views.Slice[netip.Addr] + mak.Set(&serviceHostCaps, serviceName, views.SliceOf([]netip.Addr{netip.MustParseAddr(serviceVIP)})) + j := must.Get(json.Marshal(serviceHostCaps)) + cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap() + mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(j)}) + control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm) + + // The Service host must be allowed to advertise the Service VIP. + control.SetSubnetRoutes(serviceHost.lb.NodeKey(), []netip.Prefix{ + netip.MustParsePrefix(serviceVIP + `/32`), + }) - if tt.extraSetup != nil { - tt.extraSetup(t, control) - } + // The Service host must be a tagged node (any tag will do). + serviceHostNode := control.Node(serviceHost.lb.NodeKey()) + serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag") + control.UpdateNode(serviceHostNode) + + // The service client must accept routes advertised by other nodes + // (RouteAll is equivalent to --accept-routes). + must.Get(serviceClient.localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ + RouteAllSet: true, + Prefs: ipn.Prefs{ + RouteAll: true, + }, + })) - // Force netmap updates to avoid race conditions. The nodes need to - // see our control updates before we can start the test. - must.Do(control.ForceNetmapUpdate(ctx, serviceHost.lb.NodeKey())) - must.Do(control.ForceNetmapUpdate(ctx, serviceClient.lb.NodeKey())) - netmapUpToDate := func(s *Server) bool { - nm := s.lb.NetMap() - return slices.ContainsFunc(nm.DNS.ExtraRecords, func(r tailcfg.DNSRecord) bool { - return r.Value == serviceVIP + // Set up DNS for our Service. + control.AddDNSRecords(tailcfg.DNSRecord{ + Name: serviceName.WithoutPrefix() + "." + control.MagicDNSDomain, + Value: serviceVIP, }) - } - for !netmapUpToDate(serviceClient) { - time.Sleep(10 * time.Millisecond) - } - for !netmapUpToDate(serviceHost) { - time.Sleep(10 * time.Millisecond) - } - // == Done setting up mock state == + if tt.extraSetup != nil { + tt.extraSetup(t, control) + } - // Start the Service listeners. - listeners := make([]*ServiceListener, 0, len(tt.modes)) - for _, input := range tt.modes { - ln := must.Get(serviceHost.ListenService(serviceName.String(), input)) - defer ln.Close() - listeners = append(listeners, ln) + // Wait until both nodes have up-to-date netmaps before + // proceeding with the test. + netmapUpToDate := func(s *Server) bool { + nm := s.lb.NetMap() + return slices.ContainsFunc(nm.DNS.ExtraRecords, func(r tailcfg.DNSRecord) bool { + return r.Value == serviceVIP + }) + } + for !netmapUpToDate(serviceClient) { + time.Sleep(10 * time.Millisecond) + } + for !netmapUpToDate(serviceHost) { + time.Sleep(10 * time.Millisecond) + } + + // == Done setting up mock state == + + // Start the Service listeners. + listeners := make([]*ServiceListener, 0, len(tt.modes)) + for _, input := range tt.modes { + ln := must.Get(serviceHost.ListenService(serviceName.String(), input)) + defer ln.Close() + listeners = append(listeners, ln) + } + + tt.run(t, listeners, serviceClient) } - tt.run(t, listeners, serviceClient) + t.Run("TUN", func(t *testing.T) { doTest(t, true) }) + t.Run("netstack", func(t *testing.T) { doTest(t, false) }) }) } } @@ -1860,6 +1949,683 @@ func mustDirect(t *testing.T, logf logger.Logf, lc1, lc2 *local.Client) { t.Error("magicsock did not find a direct path from lc1 to lc2") } +// chanTUN is a tun.Device for testing that uses channels for packet I/O. +// Inbound receives packets written to the TUN (from the perspective of the network stack). +// Outbound is for injecting packets to be read from the TUN. +type chanTUN struct { + Inbound chan []byte // packets written to TUN + Outbound chan []byte // packets to read from TUN + closed chan struct{} + events chan tun.Event +} + +func newChanTUN() *chanTUN { + t := &chanTUN{ + Inbound: make(chan []byte, 1024), + Outbound: make(chan []byte, 1024), + closed: make(chan struct{}), + events: make(chan tun.Event, 1), + } + t.events <- tun.EventUp + return t +} + +func (t *chanTUN) File() *os.File { panic("not implemented") } + +func (t *chanTUN) Close() error { + select { + case <-t.closed: + default: + close(t.closed) + close(t.Inbound) + } + return nil +} + +func (t *chanTUN) Read(bufs [][]byte, sizes []int, offset int) (int, error) { + select { + case <-t.closed: + return 0, io.EOF + case pkt := <-t.Outbound: + sizes[0] = copy(bufs[0][offset:], pkt) + return 1, nil + } +} + +func (t *chanTUN) Write(bufs [][]byte, offset int) (int, error) { + for _, buf := range bufs { + pkt := buf[offset:] + if len(pkt) == 0 { + continue + } + select { + case <-t.closed: + return 0, errors.New("closed") + case t.Inbound <- slices.Clone(pkt): + default: + // Drop the packet if the channel is full, like a real + // TUN under congestion. This avoids blocking the + // WireGuard send path when no goroutine is draining. + } + } + return len(bufs), nil +} + +func (t *chanTUN) MTU() (int, error) { return 1280, nil } +func (t *chanTUN) Name() (string, error) { return "chantun", nil } +func (t *chanTUN) Events() <-chan tun.Event { return t.events } +func (t *chanTUN) BatchSize() int { return 1 } + +// listenTest provides common setup for listener and TUN tests. +type listenTest struct { + control *testcontrol.Server + s1, s2 *Server + s1ip4, s1ip6 netip.Addr + s2ip4, s2ip6 netip.Addr + tun *chanTUN // nil for netstack mode +} + +// setupTwoClientTest creates two tsnet servers for testing. +// If useTUN is true, s2 uses a chanTUN; otherwise it uses netstack only. +func setupTwoClientTest(t *testing.T, useTUN bool) *listenTest { + t.Helper() + tstest.Shard(t) + tstest.ResourceCheck(t) + ctx := t.Context() + controlURL, control := startControl(t) + s1, _, _ := startServer(t, ctx, controlURL, "s1") + + tmp := filepath.Join(t.TempDir(), "s2") + must.Do(os.MkdirAll(tmp, 0755)) + s2 := &Server{ + Dir: tmp, + ControlURL: controlURL, + Hostname: "s2", + Store: new(mem.Store), + Ephemeral: true, + } + + var tun *chanTUN + if useTUN { + tun = newChanTUN() + s2.Tun = tun + } + + if *verboseNodes { + s2.Logf = t.Logf + } + t.Cleanup(func() { s2.Close() }) + + s2status, err := s2.Up(ctx) + if err != nil { + t.Fatal(err) + } + s2.lb.ConfigureCertsForTest(testCertRoot.getCert) + + s1ip4, s1ip6 := s1.TailscaleIPs() + s2ip4 := s2status.TailscaleIPs[0] + var s2ip6 netip.Addr + if len(s2status.TailscaleIPs) > 1 { + s2ip6 = s2status.TailscaleIPs[1] + } + + lc1 := must.Get(s1.LocalClient()) + must.Get(lc1.Ping(ctx, s2ip4, tailcfg.PingTSMP)) + + return &listenTest{ + control: control, + s1: s1, + s2: s2, + s1ip4: s1ip4, + s1ip6: s1ip6, + s2ip4: s2ip4, + s2ip6: s2ip6, + tun: tun, + } +} + +// echoUDP returns an IP packet with src/dst and ports swapped, with checksums recomputed. +func echoUDP(pkt []byte) []byte { + var p packet.Parsed + p.Decode(pkt) + if p.IPProto != ipproto.UDP { + return nil + } + switch p.IPVersion { + case 4: + h := p.UDP4Header() + h.ToResponse() + return packet.Generate(h, p.Payload()) + case 6: + h := packet.UDP6Header{ + IP6Header: p.IP6Header(), + SrcPort: p.Src.Port(), + DstPort: p.Dst.Port(), + } + h.ToResponse() + return packet.Generate(h, p.Payload()) + } + return nil +} + +func TestTUN(t *testing.T) { + tt := setupTwoClientTest(t, true) + + go func() { + for pkt := range tt.tun.Inbound { + var p packet.Parsed + p.Decode(pkt) + if p.Dst.Port() == 9999 { + tt.tun.Outbound <- echoUDP(pkt) + } + } + }() + + test := func(t *testing.T, s2ip netip.Addr) { + conn, err := tt.s1.Dial(t.Context(), "udp", netip.AddrPortFrom(s2ip, 9999).String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + want := "hello from s1" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatal(err) + } + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + t.Fatalf("reading echo response: %v", err) + } + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + } + + t.Run("IPv4", func(t *testing.T) { test(t, tt.s2ip4) }) + t.Run("IPv6", func(t *testing.T) { test(t, tt.s2ip6) }) +} + +// TestTUNDNS tests that a TUN can send DNS queries to quad-100 and receive +// responses. This verifies that handleLocalPackets intercepts outbound traffic +// to the service IP. +func TestTUNDNS(t *testing.T) { + tt := setupTwoClientTest(t, true) + + test := func(t *testing.T, srcIP netip.Addr, serviceIP netip.Addr) { + tt.tun.Outbound <- buildDNSQuery("s2", srcIP) + + ipVersion := uint8(4) + if srcIP.Is6() { + ipVersion = 6 + } + for { + select { + case pkt := <-tt.tun.Inbound: + var p packet.Parsed + p.Decode(pkt) + if p.IPVersion != ipVersion || p.IPProto != ipproto.UDP { + continue + } + if p.Src.Addr() == serviceIP && p.Src.Port() == 53 { + if len(p.Payload()) < 12 { + t.Fatalf("DNS response too short: %d bytes", len(p.Payload())) + } + return // success + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for DNS response") + } + } + } + + t.Run("IPv4", func(t *testing.T) { + test(t, tt.s2ip4, netip.MustParseAddr("100.100.100.100")) + }) + t.Run("IPv6", func(t *testing.T) { + test(t, tt.s2ip6, netip.MustParseAddr("fd7a:115c:a1e0::53")) + }) +} + +// TestListenPacket tests UDP listeners (ListenPacket) in both netstack and TUN modes. +func TestListenPacket(t *testing.T) { + testListenPacket := func(t *testing.T, lt *listenTest, listenIP netip.Addr) { + pc, err := lt.s2.ListenPacket("udp", netip.AddrPortFrom(listenIP, 0).String()) + if err != nil { + t.Fatal(err) + } + defer pc.Close() + + echoErr := make(chan error, 1) + go func() { + buf := make([]byte, 1500) + n, addr, err := pc.ReadFrom(buf) + if err != nil { + echoErr <- err + return + } + _, err = pc.WriteTo(buf[:n], addr) + if err != nil { + echoErr <- err + return + } + }() + + conn, err := lt.s1.Dial(t.Context(), "udp", pc.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + want := "hello udp" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatal(err) + } + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + select { + case e := <-echoErr: + t.Fatalf("echo error: %v; read error: %v", e, err) + default: + t.Fatalf("Read failed: %v", err) + } + } + + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupTwoClientTest(t, false) + t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) }) + t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) }) + }) + + t.Run("TUN", func(t *testing.T) { + lt := setupTwoClientTest(t, true) + t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) }) + t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) }) + }) +} + +// TestListenTCP tests TCP listeners with concrete addresses in both netstack +// and TUN modes. +func TestListenTCP(t *testing.T) { + testListenTCP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) { + ln, err := lt.s2.Listen("tcp", netip.AddrPortFrom(listenIP, 0).String()) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + echoErr := make(chan error, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + echoErr <- err + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + echoErr <- err + return + } + _, err = conn.Write(buf[:n]) + if err != nil { + echoErr <- err + return + } + }() + + conn, err := lt.s1.Dial(t.Context(), "tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer conn.Close() + + want := "hello tcp" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatalf("Write failed: %v", err) + } + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + select { + case e := <-echoErr: + t.Fatalf("echo error: %v; read error: %v", e, err) + default: + t.Fatalf("Read failed: %v", err) + } + } + + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupTwoClientTest(t, false) + t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) }) + t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) }) + }) + + t.Run("TUN", func(t *testing.T) { + lt := setupTwoClientTest(t, true) + t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) }) + t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) }) + }) +} + +// TestListenTCPDualStack tests TCP listeners with wildcard addresses (dual-stack) +// in both netstack and TUN modes. +func TestListenTCPDualStack(t *testing.T) { + testListenTCPDualStack := func(t *testing.T, lt *listenTest, dialIP netip.Addr) { + ln, err := lt.s2.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + _, portStr, err := net.SplitHostPort(ln.Addr().String()) + if err != nil { + t.Fatalf("parsing listener address %q: %v", ln.Addr().String(), err) + } + + echoErr := make(chan error, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + echoErr <- err + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + echoErr <- err + return + } + _, err = conn.Write(buf[:n]) + if err != nil { + echoErr <- err + return + } + }() + + dialAddr := net.JoinHostPort(dialIP.String(), portStr) + conn, err := lt.s1.Dial(t.Context(), "tcp", dialAddr) + if err != nil { + t.Fatalf("Dial(%q) failed: %v", dialAddr, err) + } + defer conn.Close() + + want := "hello tcp dualstack" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatalf("Write failed: %v", err) + } + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + select { + case e := <-echoErr: + t.Fatalf("echo error: %v; read error: %v", e, err) + default: + t.Fatalf("Read failed: %v", err) + } + } + + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupTwoClientTest(t, false) + t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) }) + t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) }) + }) + + t.Run("TUN", func(t *testing.T) { + lt := setupTwoClientTest(t, true) + t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) }) + t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) }) + }) +} + +// TestDialTCP tests TCP dialing from s2 to s1 in both netstack and TUN modes. +// In TUN mode, this verifies that outbound TCP connections and their replies +// are handled by netstack without packets escaping to the TUN. +func TestDialTCP(t *testing.T) { + testDialTCP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) { + ln, err := lt.s1.Listen("tcp", netip.AddrPortFrom(listenIP, 0).String()) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + echoErr := make(chan error, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + echoErr <- err + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + echoErr <- err + return + } + _, err = conn.Write(buf[:n]) + if err != nil { + echoErr <- err + return + } + }() + + conn, err := lt.s2.Dial(t.Context(), "tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer conn.Close() + + want := "hello tcp dial" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatalf("Write failed: %v", err) + } + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + select { + case e := <-echoErr: + t.Fatalf("echo error: %v; read error: %v", e, err) + default: + t.Fatalf("Read failed: %v", err) + } + } + + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupTwoClientTest(t, false) + t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) }) + t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) }) + }) + + t.Run("TUN", func(t *testing.T) { + lt := setupTwoClientTest(t, true) + + var escapedTCPPackets atomic.Int32 + var wg sync.WaitGroup + wg.Go(func() { + for pkt := range lt.tun.Inbound { + var p packet.Parsed + p.Decode(pkt) + if p.IPProto == ipproto.TCP { + escapedTCPPackets.Add(1) + t.Logf("TCP packet escaped to TUN: %v -> %v", p.Src, p.Dst) + } + } + }) + + t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) }) + t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) }) + + lt.tun.Close() + wg.Wait() + if escaped := escapedTCPPackets.Load(); escaped > 0 { + t.Errorf("%d TCP packets escaped to TUN", escaped) + } + }) +} + +// TestDialUDP tests UDP dialing from s2 to s1 in both netstack and TUN modes. +// In TUN mode, this verifies that outbound UDP connections register endpoints +// with gVisor, allowing reply packets to be routed through netstack instead of +// escaping to the TUN. +func TestDialUDP(t *testing.T) { + testDialUDP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) { + pc, err := lt.s1.ListenPacket("udp", netip.AddrPortFrom(listenIP, 0).String()) + if err != nil { + t.Fatal(err) + } + defer pc.Close() + + echoErr := make(chan error, 1) + go func() { + buf := make([]byte, 1500) + n, addr, err := pc.ReadFrom(buf) + if err != nil { + echoErr <- err + return + } + _, err = pc.WriteTo(buf[:n], addr) + if err != nil { + echoErr <- err + return + } + }() + + conn, err := lt.s2.Dial(t.Context(), "udp", pc.LocalAddr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer conn.Close() + + want := "hello udp dial" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatalf("Write failed: %v", err) + } + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + select { + case e := <-echoErr: + t.Fatalf("echo error: %v; read error: %v", e, err) + default: + t.Fatalf("Read failed: %v", err) + } + } + + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupTwoClientTest(t, false) + t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) }) + t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) }) + }) + + t.Run("TUN", func(t *testing.T) { + lt := setupTwoClientTest(t, true) + + var escapedUDPPackets atomic.Int32 + var wg sync.WaitGroup + wg.Go(func() { + for pkt := range lt.tun.Inbound { + var p packet.Parsed + p.Decode(pkt) + if p.IPProto == ipproto.UDP { + escapedUDPPackets.Add(1) + t.Logf("UDP packet escaped to TUN: %v -> %v", p.Src, p.Dst) + } + } + }) + + t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) }) + t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) }) + + lt.tun.Close() + wg.Wait() + if escaped := escapedUDPPackets.Load(); escaped > 0 { + t.Errorf("%d UDP packets escaped to TUN", escaped) + } + }) +} + +// buildDNSQuery builds a UDP/IP packet containing a DNS query for name to the +// Tailscale service IP (100.100.100.100 for IPv4, fd7a:115c:a1e0::53 for IPv6). +func buildDNSQuery(name string, srcIP netip.Addr) []byte { + qtype := byte(0x01) // Type A for IPv4 + if srcIP.Is6() { + qtype = 0x1c // Type AAAA for IPv6 + } + dns := []byte{ + 0x12, 0x34, // ID + 0x01, 0x00, // Flags: standard query, recursion desired + 0x00, 0x01, // QDCOUNT: 1 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ANCOUNT, NSCOUNT, ARCOUNT + } + for _, label := range strings.Split(name, ".") { + dns = append(dns, byte(len(label))) + dns = append(dns, label...) + } + dns = append(dns, 0x00, 0x00, qtype, 0x00, 0x01) // null, Type A/AAAA, Class IN + + if srcIP.Is4() { + h := packet.UDP4Header{ + IP4Header: packet.IP4Header{ + Src: srcIP, + Dst: netip.MustParseAddr("100.100.100.100"), + }, + SrcPort: 12345, + DstPort: 53, + } + return packet.Generate(h, dns) + } + h := packet.UDP6Header{ + IP6Header: packet.IP6Header{ + Src: srcIP, + Dst: netip.MustParseAddr("fd7a:115c:a1e0::53"), + }, + SrcPort: 12345, + DstPort: 53, + } + return packet.Generate(h, dns) +} + func TestDeps(t *testing.T) { tstest.Shard(t) deptest.DepChecker{ @@ -2110,3 +2876,232 @@ func TestResolveAuthKey(t *testing.T) { }) } } + +// TestSelfDial verifies that a single tsnet.Server can Dial its own Listen +// address. This is a regression test for a bug where self-addressed TCP SYN +// packets were sent to WireGuard (which has no peer for the node's own IP) +// and silently dropped, causing Dial to hang indefinitely. +func TestSelfDial(t *testing.T) { + tstest.Shard(t) + tstest.ResourceCheck(t) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + controlURL, _ := startControl(t) + s1, s1ip, _ := startServer(t, ctx, controlURL, "s1") + + ln, err := s1.Listen("tcp", ":8081") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + errc := make(chan error, 1) + connc := make(chan net.Conn, 1) + go func() { + c, err := ln.Accept() + if err != nil { + errc <- err + return + } + connc <- c + }() + + // Self-dial: the same server dials its own Tailscale IP. + w, err := s1.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)) + if err != nil { + t.Fatalf("self-dial failed: %v", err) + } + defer w.Close() + + var accepted net.Conn + select { + case accepted = <-connc: + case err := <-errc: + t.Fatalf("accept failed: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for accept") + } + defer accepted.Close() + + // Verify bidirectional data exchange. + want := "hello self" + if _, err := io.WriteString(w, want); err != nil { + t.Fatal(err) + } + got := make([]byte, len(want)) + if _, err := io.ReadFull(accepted, got); err != nil { + t.Fatal(err) + } + if string(got) != want { + t.Errorf("client->server: got %q, want %q", got, want) + } + + reply := "hello back" + if _, err := io.WriteString(accepted, reply); err != nil { + t.Fatal(err) + } + gotReply := make([]byte, len(reply)) + if _, err := io.ReadFull(w, gotReply); err != nil { + t.Fatal(err) + } + if string(gotReply) != reply { + t.Errorf("server->client: got %q, want %q", gotReply, reply) + } +} + +// TestListenUnspecifiedAddr verifies that listening on 0.0.0.0 or [::] works +// the same as listening on an empty host (":port"), accepting connections +// destined to the node's Tailscale IPs. +func TestListenUnspecifiedAddr(t *testing.T) { + testUnspec := func(t *testing.T, lt *listenTest, addr, dialPort string) { + ln, err := lt.s2.Listen("tcp", addr) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + echoErr := make(chan error, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + echoErr <- err + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + echoErr <- err + return + } + _, err = conn.Write(buf[:n]) + echoErr <- err + }() + + dialAddr := net.JoinHostPort(lt.s2ip4.String(), dialPort) + conn, err := lt.s1.Dial(t.Context(), "tcp", dialAddr) + if err != nil { + t.Fatalf("Dial(%q) failed: %v", dialAddr, err) + } + defer conn.Close() + want := "hello unspec" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatalf("Write failed: %v", err) + } + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + if err := <-echoErr; err != nil { + t.Fatalf("echo error: %v", err) + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupTwoClientTest(t, false) + t.Run("0.0.0.0", func(t *testing.T) { testUnspec(t, lt, "0.0.0.0:8080", "8080") }) + t.Run("::", func(t *testing.T) { testUnspec(t, lt, "[::]:8081", "8081") }) + }) + t.Run("TUN", func(t *testing.T) { + lt := setupTwoClientTest(t, true) + t.Run("0.0.0.0", func(t *testing.T) { testUnspec(t, lt, "0.0.0.0:8080", "8080") }) + t.Run("::", func(t *testing.T) { testUnspec(t, lt, "[::]:8081", "8081") }) + }) +} + +// TestListenMultipleEphemeralPorts verifies that calling Listen with port 0 +// multiple times allocates distinct ports, each of which can receive +// connections independently. +func TestListenMultipleEphemeralPorts(t *testing.T) { + testMultipleEphemeral := func(t *testing.T, lt *listenTest) { + const n = 3 + listeners := make([]net.Listener, n) + ports := make([]string, n) + for i := range n { + ln, err := lt.s2.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { ln.Close() }) + _, portStr, err := net.SplitHostPort(ln.Addr().String()) + if err != nil { + t.Fatalf("parsing Addr %q: %v", ln.Addr(), err) + } + if portStr == "0" { + t.Fatal("Addr() returned port 0; expected allocated port") + } + for j := range i { + if ports[j] == portStr { + t.Fatalf("listeners %d and %d both got port %s", j, i, portStr) + } + } + listeners[i] = ln + ports[i] = portStr + } + + // Verify each listener independently accepts connections. + for i := range n { + echoErr := make(chan error, 1) + go func() { + conn, err := listeners[i].Accept() + if err != nil { + echoErr <- err + return + } + defer conn.Close() + buf := make([]byte, 1024) + rn, err := conn.Read(buf) + if err != nil { + echoErr <- err + return + } + _, err = conn.Write(buf[:rn]) + echoErr <- err + }() + + dialAddr := net.JoinHostPort(lt.s2ip4.String(), ports[i]) + conn, err := lt.s1.Dial(t.Context(), "tcp", dialAddr) + if err != nil { + t.Fatalf("listener %d: Dial(%q) failed: %v", i, dialAddr, err) + } + want := fmt.Sprintf("hello port %d", i) + if _, err := conn.Write([]byte(want)); err != nil { + conn.Close() + t.Fatalf("listener %d: Write failed: %v", i, err) + } + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + rn, err := conn.Read(got) + conn.Close() + if err != nil { + select { + case e := <-echoErr: + t.Fatalf("listener %d: echo error: %v; read error: %v", i, e, err) + default: + t.Fatalf("listener %d: Read failed: %v", i, err) + } + } + if string(got[:rn]) != want { + t.Errorf("listener %d: got %q, want %q", i, got[:rn], want) + } + if err := <-echoErr; err != nil { + t.Fatalf("listener %d: echo error: %v", i, err) + } + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupTwoClientTest(t, false) + testMultipleEphemeral(t, lt) + }) + t.Run("TUN", func(t *testing.T) { + lt := setupTwoClientTest(t, true) + testMultipleEphemeral(t, lt) + }) +} diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 447efb0c1b15d..1e24414903ae9 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -1,4 +1,4 @@ -// Copyright (c) Tailscale Inc & AUTHORS +// Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause // Package testcontrol contains a minimal control plane server for testing purposes. @@ -299,43 +299,6 @@ func (s *Server) addDebugMessage(nodeKeyDst key.NodePublic, msg any) bool { return sendUpdate(oldUpdatesCh, updateDebugInjection) } -// ForceNetmapUpdate waits for the node to get stuck in a map poll and then -// sends the current netmap (which may result in a redundant netmap). The -// intended use case is ensuring state changes propagate before running tests. -// -// This should only be called for nodes connected as streaming clients. Calling -// this with a non-streaming node will result in non-deterministic behavior. -// -// This function cannot guarantee that the node has processed the issued update, -// so tests should confirm processing by querying the node. By example: -// -// if err := s.ForceNetmapUpdate(node.Key()); err != nil { -// // handle error -// } -// for !updatesPresent(node.NetMap()) { -// time.Sleep(10 * time.Millisecond) -// } -func (s *Server) ForceNetmapUpdate(ctx context.Context, nodeKey key.NodePublic) error { - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - if err := s.AwaitNodeInMapRequest(ctx, nodeKey); err != nil { - return fmt.Errorf("waiting for node to poll: %w", err) - } - mr, err := s.MapResponse(&tailcfg.MapRequest{NodeKey: nodeKey}) - if err != nil { - return fmt.Errorf("generating map response: %w", err) - } - if s.addDebugMessage(nodeKey, mr) { - return nil - } - // If we failed to send the map response, loop around and try again. - } -} - // Mark the Node key of every node as expired func (s *Server) SetExpireAllNodes(expired bool) { s.mu.Lock() @@ -589,8 +552,9 @@ func (s *Server) SetNodeCapMap(nodeKey key.NodePublic, capMap tailcfg.NodeCapMap // ] func (s *Server) SetGlobalAppCaps(appCaps tailcfg.PeerCapMap) { s.mu.Lock() + defer s.mu.Unlock() s.globalAppCaps = appCaps - s.mu.Unlock() + s.updateLocked("SetGlobalAppCaps", s.nodeIDsLocked(0)) } // AddDNSRecords adds records to the server's DNS config. @@ -601,6 +565,7 @@ func (s *Server) AddDNSRecords(records ...tailcfg.DNSRecord) { s.DNSConfig = new(tailcfg.DNSConfig) } s.DNSConfig.ExtraRecords = append(s.DNSConfig.ExtraRecords, records...) + s.updateLocked("AddDNSRecords", s.nodeIDsLocked(0)) } // nodeIDsLocked returns the node IDs of all nodes in the server, except @@ -1110,9 +1075,7 @@ func sendUpdate(dst chan<- updateType, updateType updateType) bool { } } -func (s *Server) UpdateNode(n *tailcfg.Node) (peersToUpdate []tailcfg.NodeID) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *Server) updateNodeLocked(n *tailcfg.Node) (peersToUpdate []tailcfg.NodeID) { if n.Key.IsZero() { panic("zero nodekey") } @@ -1120,6 +1083,15 @@ func (s *Server) UpdateNode(n *tailcfg.Node) (peersToUpdate []tailcfg.NodeID) { return s.nodeIDsLocked(n.ID) } +// UpdateNode updates or adds the input node, then triggers a netmap update for +// all attached streaming clients. +func (s *Server) UpdateNode(n *tailcfg.Node) { + s.mu.Lock() + defer s.mu.Unlock() + s.updateNodeLocked(n) + s.updateLocked("UpdateNode", s.nodeIDsLocked(0)) +} + func (s *Server) incrInServeMap(delta int) { s.mu.Lock() defer s.mu.Unlock() @@ -1178,7 +1150,9 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi } } } - peersToUpdate = s.UpdateNode(node) + s.mu.Lock() + peersToUpdate = s.updateNodeLocked(node) + s.mu.Unlock() } nodeID := node.ID @@ -1327,16 +1301,19 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, s.mu.Lock() nodeCapMap := maps.Clone(s.nodeCapMaps[nk]) + var dns *tailcfg.DNSConfig + if s.DNSConfig != nil { + dns = s.DNSConfig.Clone() + } + magicDNSDomain := s.MagicDNSDomain s.mu.Unlock() node.CapMap = nodeCapMap node.Capabilities = append(node.Capabilities, tailcfg.NodeAttrDisableUPnP) t := time.Date(2020, 8, 3, 0, 0, 0, 1, time.UTC) - dns := s.DNSConfig - if dns != nil && s.MagicDNSDomain != "" { - dns = dns.Clone() - dns.CertDomains = append(dns.CertDomains, node.Hostinfo.Hostname()+"."+s.MagicDNSDomain) + if dns != nil && magicDNSDomain != "" { + dns.CertDomains = append(dns.CertDomains, node.Hostinfo.Hostname()+"."+magicDNSDomain) } res = &tailcfg.MapResponse{