diff --git a/.github/workflows/cigocacher.yml b/.github/workflows/cigocacher.yml index 15aec8af90904..9e7f01725958e 100644 --- a/.github/workflows/cigocacher.yml +++ b/.github/workflows/cigocacher.yml @@ -24,7 +24,7 @@ jobs: ./tool/go build -o "${OUT}" ./cmd/cigocacher/ tar -zcf cigocacher-${{ matrix.GOOS }}-${{ matrix.GOARCH }}.tar.gz "${OUT}" - - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0 + - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: cigocacher-${{ matrix.GOOS }}-${{ matrix.GOARCH }} path: cigocacher-${{ matrix.GOOS }}-${{ matrix.GOARCH }}.tar.gz @@ -36,7 +36,7 @@ jobs: contents: write steps: - name: Download all artifacts - uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: pattern: 'cigocacher-*' merge-multiple: true diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 51bae5a068df5..abe6a2c3ae684 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -55,7 +55,7 @@ jobs: # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@c793b717bc78562f491db7b0e93a3a178b099162 # v4.32.5 + uses: github/codeql-action/init@38697555549f1db7851b81482ff19f1fa5c4fedc # v4.34.1 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -66,7 +66,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@c793b717bc78562f491db7b0e93a3a178b099162 # v4.32.5 + uses: github/codeql-action/autobuild@38697555549f1db7851b81482ff19f1fa5c4fedc # v4.34.1 # â„šī¸ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -80,4 +80,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@c793b717bc78562f491db7b0e93a3a178b099162 # v4.32.5 + uses: github/codeql-action/analyze@38697555549f1db7851b81482ff19f1fa5c4fedc # v4.34.1 diff --git a/.github/workflows/natlab-integrationtest.yml b/.github/workflows/natlab-integrationtest.yml index 162153cb23293..6c0d4957543c3 100644 --- a/.github/workflows/natlab-integrationtest.yml +++ b/.github/workflows/natlab-integrationtest.yml @@ -22,12 +22,23 @@ jobs: steps: - name: Check out code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Enable KVM + run: | + echo 'KERNEL=="kvm", GROUP="kvm", MODE="0666", OPTIONS+="static_node=kvm"' | sudo tee /etc/udev/rules.d/99-kvm4all.rules + sudo udevadm control --reload-rules + sudo udevadm trigger --name-match=kvm - name: Install qemu run: | sudo rm -f /var/lib/man-db/auto-update sudo apt-get -y update sudo apt-get -y remove man-db sudo apt-get install -y qemu-system-x86 qemu-utils + - name: Build VM image + # The test will build this if missing, but we do it explicitly + # to avoid cutting into the go test -timeout budget, and to + # fail earlier with a clearer error if the image build breaks. + run: | + make -C gokrazy natlab - name: Run natlab integration tests run: | ./tool/go test -v -run=^TestEasyEasy$ -timeout=3m -count=1 ./tstest/integration/nat --run-vm-tests diff --git a/.github/workflows/request-dataplane-review.yml b/.github/workflows/request-dataplane-review.yml index 2b66fc7899428..78bd8ff585bff 100644 --- a/.github/workflows/request-dataplane-review.yml +++ b/.github/workflows/request-dataplane-review.yml @@ -18,7 +18,7 @@ jobs: - name: Check out code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Get access token - uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1 + uses: actions/create-github-app-token@f8d387b68d61c58ab83c6c016672934102569859 # v3.0.0 id: generate-token with: # Get token for app: https://github.com/apps/change-visibility-bot diff --git a/.github/workflows/ssh-integrationtest.yml b/.github/workflows/ssh-integrationtest.yml index afe2dd2f74683..84432cd729418 100644 --- a/.github/workflows/ssh-integrationtest.yml +++ b/.github/workflows/ssh-integrationtest.yml @@ -1,5 +1,5 @@ -# Run the ssh integration tests with `make sshintegrationtest`. -# These tests can also be running locally. +# Run the ssh integration tests in various Docker containers. +# These tests can also be run locally via `make sshintegrationtest`. name: "ssh-integrationtest" concurrency: @@ -15,9 +15,25 @@ on: jobs: ssh-integrationtest: runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - base: "ubuntu:focal" + tag: "ssh-ubuntu-focal" + - base: "ubuntu:jammy" + tag: "ssh-ubuntu-jammy" + - base: "ubuntu:noble" + tag: "ssh-ubuntu-noble" + - base: "alpine:latest" + tag: "ssh-alpine-latest" steps: - name: Check out code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - - name: Run SSH integration tests + - name: Build test binaries run: | - make sshintegrationtest \ No newline at end of file + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 ./tool/go test -tags integrationtest -c ./ssh/tailssh -o ssh/tailssh/testcontainers/tailssh.test + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 ./tool/go build -o ssh/tailssh/testcontainers/tailscaled ./cmd/tailscaled + - name: Run SSH integration tests (${{ matrix.base }}) + run: | + docker build --build-arg="BASE=${{ matrix.base }}" -t "${{ matrix.tag }}" ssh/tailssh/testcontainers diff --git a/.github/workflows/update-flake.yml b/.github/workflows/update-flake.yml index 4c0da7831b5ba..ce77cf651ad42 100644 --- a/.github/workflows/update-flake.yml +++ b/.github/workflows/update-flake.yml @@ -23,11 +23,11 @@ jobs: - name: Check out code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - - name: Run update-flakes - run: ./update-flake.sh + - name: Run updateflakes + run: ./tool/go run ./tool/updateflakes - name: Get access token - uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1 + uses: actions/create-github-app-token@f8d387b68d61c58ab83c6c016672934102569859 # v3.0.0 id: generate-token with: # Get token for app: https://github.com/apps/tailscale-code-updater @@ -41,8 +41,8 @@ jobs: author: Flakes Updater committer: Flakes Updater branch: flakes - commit-message: "go.mod.sri: update SRI hash for go.mod changes" - title: "go.mod.sri: update SRI hash for go.mod changes" + commit-message: "flakehashes.json: update SRI hash for go.mod changes" + title: "flakehashes.json: update SRI hash for go.mod changes" body: Triggered by ${{ github.repository }}@${{ github.sha }} signoff: true delete-branch: true diff --git a/.github/workflows/update-webclient-prebuilt.yml b/.github/workflows/update-webclient-prebuilt.yml index a3d78e1a5b4a8..5bb0573a1f18c 100644 --- a/.github/workflows/update-webclient-prebuilt.yml +++ b/.github/workflows/update-webclient-prebuilt.yml @@ -23,7 +23,7 @@ jobs: ./tool/go mod tidy - name: Get access token - uses: actions/create-github-app-token@29824e69f54612133e76f7eaac726eef6c875baf # v2.2.1 + uses: actions/create-github-app-token@f8d387b68d61c58ab83c6c016672934102569859 # v3.0.0 id: generate-token with: # Get token for app: https://github.com/apps/tailscale-code-updater diff --git a/.github/workflows/vet.yml b/.github/workflows/vet.yml index 574852e62beee..c03190e4fcf1a 100644 --- a/.github/workflows/vet.yml +++ b/.github/workflows/vet.yml @@ -36,4 +36,10 @@ jobs: - name: Run 'go vet' working-directory: src - run: ./tool/go vet -vettool=/tmp/vettool tailscale.com/... + # Use listpkgs --ignore-3p to skip tempfork/ packages, which + # intentionally match upstream and may not follow our style rules. + # Must use ./... instead of tailscale.com/... because the latter will + # include the v2 go client (tailscale.com/client/tailscale/v2) if it's + # a dependency in our go.mod file. Possibly a go vet bug, but avoid + # cross-repo vetting for now so we can safely add the dependency. + run: ./tool/go vet -vettool=/tmp/vettool $(./tool/go run ./tool/listpkgs --ignore-3p ./...) diff --git a/.gitignore b/.gitignore index 4bfabc80f0415..e1f6be02e002f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,15 @@ # Binaries for programs and plugins *~ *.tmp -*.exe *.dll *.so *.dylib *.spk +*.exe +# tool/go.exe is built specially and committed. +!/tool/go.exe + cmd/tailscale/tailscale cmd/tailscaled/tailscaled ssh/tailssh/testcontainers/tailscaled diff --git a/.golangci.yml b/.golangci.yml index eb34f9d9efc76..ff8bd07228677 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -10,9 +10,15 @@ linters: enable: - bidichk - govet + - importas - misspell - revive settings: + importas: + no-unaliased: true + alias: + - pkg: github.com/tailscale/gliderssh + alias: gliderssh # Matches what we use in corp as of 2023-12-07 govet: enable: diff --git a/Makefile b/Makefile index b78ef046913a7..0efd57fb486d6 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ vet: ## Run go vet tidy: ## Run go mod tidy and update nix flake hashes ./tool/go mod tidy - ./update-flake.sh + ./tool/go run ./tool/updateflakes lint: ## Run golangci-lint ./tool/go run github.com/golangci/golangci-lint/cmd/golangci-lint run @@ -137,10 +137,12 @@ publishdevproxy: check-image-repo ## Build and publish k8s-proxy image to locati sshintegrationtest: ## Run the SSH integration tests in various Docker containers @GOOS=linux GOARCH=amd64 CGO_ENABLED=0 ./tool/go test -tags integrationtest -c ./ssh/tailssh -o ssh/tailssh/testcontainers/tailssh.test && \ GOOS=linux GOARCH=amd64 CGO_ENABLED=0 ./tool/go build -o ssh/tailssh/testcontainers/tailscaled ./cmd/tailscaled && \ - echo "Testing on ubuntu:focal" && docker build --build-arg="BASE=ubuntu:focal" -t ssh-ubuntu-focal ssh/tailssh/testcontainers && \ - echo "Testing on ubuntu:jammy" && docker build --build-arg="BASE=ubuntu:jammy" -t ssh-ubuntu-jammy ssh/tailssh/testcontainers && \ - echo "Testing on ubuntu:noble" && docker build --build-arg="BASE=ubuntu:noble" -t ssh-ubuntu-noble ssh/tailssh/testcontainers && \ - echo "Testing on alpine:latest" && docker build --build-arg="BASE=alpine:latest" -t ssh-alpine-latest ssh/tailssh/testcontainers + echo "Testing on ubuntu:focal, ubuntu:jammy, ubuntu:noble, alpine:latest (in parallel)" && \ + docker build --build-arg="BASE=ubuntu:focal" -t ssh-ubuntu-focal ssh/tailssh/testcontainers & \ + docker build --build-arg="BASE=ubuntu:jammy" -t ssh-ubuntu-jammy ssh/tailssh/testcontainers & \ + docker build --build-arg="BASE=ubuntu:noble" -t ssh-ubuntu-noble ssh/tailssh/testcontainers & \ + docker build --build-arg="BASE=alpine:latest" -t ssh-alpine-latest ssh/tailssh/testcontainers & \ + wait .PHONY: generate generate: ## Generate code diff --git a/README.md b/README.md index 70b92d411b9de..1d8208a867814 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ not open source. ## Building -We always require the latest Go release, currently Go 1.25. (While we build +We always require the latest Go release, currently Go 1.26. (While we build releases with our [Go fork](https://github.com/tailscale/go/), its use is not required.) diff --git a/appc/appconnector_test.go b/appc/appconnector_test.go index a860da6a7c737..c58aa80410869 100644 --- a/appc/appconnector_test.go +++ b/appc/appconnector_test.go @@ -698,7 +698,7 @@ func TestRateLogger(t *testing.T) { wasCalled = true }) - for i := 0; i < 3; i++ { + for range 3 { clock.Advance(1 * time.Millisecond) rl.update(0) if wasCalled { @@ -720,7 +720,7 @@ func TestRateLogger(t *testing.T) { wasCalled = true }) - for i := 0; i < 3; i++ { + for range 3 { clock.Advance(1 * time.Minute) rl.update(0) if wasCalled { @@ -736,6 +736,7 @@ func TestRateLogger(t *testing.T) { } func TestRouteStoreMetrics(t *testing.T) { + clientmetric.ResetForTest(t) metricStoreRoutes(1, 1) metricStoreRoutes(1, 1) // the 1 buckets value should be 2 metricStoreRoutes(5, 5) // the 5 buckets value should be 1 diff --git a/appc/conn25.go b/appc/conn25.go index 08b2a1ade6826..62cb70017824a 100644 --- a/appc/conn25.go +++ b/appc/conn25.go @@ -6,7 +6,9 @@ package appc import ( "cmp" "slices" + "strings" + "tailscale.com/ipn/ipnext" "tailscale.com/tailcfg" "tailscale.com/types/appctype" "tailscale.com/util/mak" @@ -15,9 +17,46 @@ import ( const AppConnectorsExperimentalAttrName = "tailscale.com/app-connectors-experimental" +func isPeerEligibleConnector(peer tailcfg.NodeView) bool { + if !peer.Valid() || !peer.Hostinfo().Valid() { + return false + } + isConn, _ := peer.Hostinfo().AppConnector().Get() + return isConn +} + +func sortByPreference(ns []tailcfg.NodeView) { + // The ordering of the nodes is semantic (callers use the first node they can + // get a peer api url for). We don't (currently 2026-02-27) have any + // preference over which node is chosen as long as it's consistent. In the + // future we anticipate integrating with traffic steering. + slices.SortFunc(ns, func(a, b tailcfg.NodeView) int { + return cmp.Compare(a.ID(), b.ID()) + }) +} + +// PickConnector returns peers the backend knows about that match the app, in order of preference to use as +// a connector. +func PickConnector(nb ipnext.NodeBackend, app appctype.Conn25Attr) []tailcfg.NodeView { + appTagsSet := set.SetOf(app.Connectors) + matches := nb.AppendMatchingPeers(nil, func(n tailcfg.NodeView) bool { + if !isPeerEligibleConnector(n) { + return false + } + for _, t := range n.Tags().All() { + if appTagsSet.Contains(t) { + return true + } + } + return false + }) + sortByPreference(matches) + return matches +} + // PickSplitDNSPeers looks at the netmap peers capabilities and finds which peers // want to be connectors for which domains. -func PickSplitDNSPeers(hasCap func(c tailcfg.NodeCapability) bool, self tailcfg.NodeView, peers map[tailcfg.NodeID]tailcfg.NodeView) map[string][]tailcfg.NodeView { +func PickSplitDNSPeers(hasCap func(c tailcfg.NodeCapability) bool, self tailcfg.NodeView, peers map[tailcfg.NodeID]tailcfg.NodeView, isSelfEligibleConnector bool) map[string][]tailcfg.NodeView { var m map[string][]tailcfg.NodeView if !hasCap(AppConnectorsExperimentalAttrName) { return m @@ -26,25 +65,43 @@ func PickSplitDNSPeers(hasCap func(c tailcfg.NodeCapability) bool, self tailcfg. if err != nil { return m } - tagToDomain := make(map[string][]string) + + // We strip the leading *. from any domains because the OS treats all domains + // that we pass to it as wildcard domains, and the OS would treat the * character + // as a literal domain component instead of treating it as a wildcard. + // We also use a Set to deduplicate the domains we pass to the OS in case removing + // the *. prefix resulted in duplicate entries. + tagToDomain := make(map[string]set.Set[string]) + selfTags := set.SetOf(self.Tags().AsSlice()) + selfRoutedDomains := set.Set[string]{} for _, app := range apps { + domains := make(set.Set[string]) + for _, domain := range app.Domains { + domains.Add(strings.ToLower(strings.TrimPrefix(domain, "*."))) + } for _, tag := range app.Connectors { - tagToDomain[tag] = append(tagToDomain[tag], app.Domains...) + if tagToDomain[tag] == nil { + tagToDomain[tag] = set.Set[string]{} + } + tagToDomain[tag].AddSet(domains) + if isSelfEligibleConnector && selfTags.Contains(tag) { + selfRoutedDomains.AddSet(domains) + } } } // NodeIDs are Comparable, and we have a map of NodeID to NodeView anyway, so // use a Set of NodeIDs to deduplicate, and populate into a []NodeView later. var work map[string]set.Set[tailcfg.NodeID] for _, peer := range peers { - if !peer.Valid() || !peer.Hostinfo().Valid() { - continue - } - if isConn, _ := peer.Hostinfo().AppConnector().Get(); !isConn { + if !isPeerEligibleConnector(peer) { continue } for _, t := range peer.Tags().All() { domains := tagToDomain[t] - for _, domain := range domains { + for domain := range domains { + if selfRoutedDomains.Contains(domain) { + continue + } if work[domain] == nil { mak.Set(&work, domain, set.Set[tailcfg.NodeID]{}) } @@ -60,12 +117,7 @@ func PickSplitDNSPeers(hasCap func(c tailcfg.NodeCapability) bool, self tailcfg. for id := range ids { nodes = append(nodes, peers[id]) } - // The ordering of the nodes in the map vals is semantic (dnsConfigForNetmap uses the first node it can - // get a peer api url for as its split dns target). We can think of it as a preference order, except that - // we don't (currently 2026-01-14) have any preference over which node is chosen. - slices.SortFunc(nodes, func(a, b tailcfg.NodeView) int { - return cmp.Compare(a.ID(), b.ID()) - }) + sortByPreference(nodes) mak.Set(&m, domain, nodes) } return m diff --git a/appc/conn25_test.go b/appc/conn25_test.go index a9cb0fb7ebf9c..dd98312ca638d 100644 --- a/appc/conn25_test.go +++ b/appc/conn25_test.go @@ -8,6 +8,8 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" + "tailscale.com/ipn/ipnext" "tailscale.com/tailcfg" "tailscale.com/types/appctype" "tailscale.com/types/opt" @@ -30,6 +32,8 @@ func TestPickSplitDNSPeers(t *testing.T) { appTwoBytes := getBytesForAttr("app2", []string{"a.example.com"}, []string{"tag:two"}) appThreeBytes := getBytesForAttr("app3", []string{"woo.b.example.com", "hoo.b.example.com"}, []string{"tag:three1", "tag:three2"}) appFourBytes := getBytesForAttr("app4", []string{"woo.b.example.com", "c.example.com"}, []string{"tag:four1", "tag:four2"}) + appFiveBytes := getBytesForAttr("app5", []string{"*.example.com", "example.com"}, []string{"tag:one"}) + appSixBytes := getBytesForAttr("app6", []string{"*.Example.com", "EXAMPLE.com", "EXAMPLE.COM"}, []string{"tag:one"}) makeNodeView := func(id tailcfg.NodeID, name string, tags []string) tailcfg.NodeView { return (&tailcfg.Node{ @@ -45,10 +49,12 @@ func TestPickSplitDNSPeers(t *testing.T) { nvp4 := makeNodeView(4, "p4", []string{"tag:two", "tag:three2", "tag:four2"}) for _, tt := range []struct { - name string - want map[string][]tailcfg.NodeView - peers []tailcfg.NodeView - config []tailcfg.RawMessage + name string + peers []tailcfg.NodeView + config []tailcfg.RawMessage + isEligibleConnector bool + selfTags []string + want map[string][]tailcfg.NodeView }{ { name: "empty", @@ -109,6 +115,128 @@ func TestPickSplitDNSPeers(t *testing.T) { "c.example.com": {nvp2, nvp4}, }, }, + { + name: "self-connector-exclude-self-domains", + config: []tailcfg.RawMessage{ + tailcfg.RawMessage(appOneBytes), + tailcfg.RawMessage(appTwoBytes), + tailcfg.RawMessage(appThreeBytes), + tailcfg.RawMessage(appFourBytes), + }, + peers: []tailcfg.NodeView{ + nvp1, + nvp2, + nvp3, + nvp4, + }, + isEligibleConnector: true, + selfTags: []string{"tag:three1"}, + want: map[string][]tailcfg.NodeView{ + // woo.b.example.com and hoo.b.example.com are covered + // by tag:three1, and so is this self-node. + // So those domains should not be routed to peers. + // woo.b.example.com is also covered by another tag, + // but still not included since this connector can route to it. + "example.com": {nvp1}, + "a.example.com": {nvp3, nvp4}, + "c.example.com": {nvp2, nvp4}, + }, + }, + { + name: "self-eligible-connector-no-matching-tag-include-all-domains", + config: []tailcfg.RawMessage{ + tailcfg.RawMessage(appOneBytes), + tailcfg.RawMessage(appTwoBytes), + tailcfg.RawMessage(appThreeBytes), + tailcfg.RawMessage(appFourBytes), + }, + peers: []tailcfg.NodeView{ + nvp1, + nvp2, + nvp3, + nvp4, + }, + isEligibleConnector: true, + selfTags: []string{"tag:unrelated"}, + want: map[string][]tailcfg.NodeView{ + // Self has prefs set but no tags matching any app, + // so no domains are self-routed and all appear. + "example.com": {nvp1}, + "a.example.com": {nvp3, nvp4}, + "woo.b.example.com": {nvp2, nvp3, nvp4}, + "hoo.b.example.com": {nvp3, nvp4}, + "c.example.com": {nvp2, nvp4}, + }, + }, + { + name: "self-not-eligible-connector-but-tagged-include-all-domains", + config: []tailcfg.RawMessage{ + tailcfg.RawMessage(appOneBytes), + tailcfg.RawMessage(appTwoBytes), + tailcfg.RawMessage(appThreeBytes), + tailcfg.RawMessage(appFourBytes), + }, + peers: []tailcfg.NodeView{ + nvp1, + nvp2, + nvp3, + nvp4, + }, + selfTags: []string{"tag:three1"}, + want: map[string][]tailcfg.NodeView{ + // Even though this self node has a tag for an app + // the prefs don't advertise as connector, so + // should still route through other connectors. + "example.com": {nvp1}, + "a.example.com": {nvp3, nvp4}, + "woo.b.example.com": {nvp2, nvp3, nvp4}, + "hoo.b.example.com": {nvp3, nvp4}, + "c.example.com": {nvp2, nvp4}, + }, + }, + { + name: "wildcards-are-stripped-and-deduped", + config: []tailcfg.RawMessage{ + tailcfg.RawMessage(appOneBytes), + tailcfg.RawMessage(appFiveBytes), + }, + peers: []tailcfg.NodeView{ + nvp1, + }, + want: map[string][]tailcfg.NodeView{ + // All the domains should be normalized to example.com + "example.com": {nvp1}, + }, + }, + { + name: "domains-are-normalized-and-deduped", + config: []tailcfg.RawMessage{ + tailcfg.RawMessage(appSixBytes), + }, + peers: []tailcfg.NodeView{ + nvp1, + }, + want: map[string][]tailcfg.NodeView{ + // All the domains should be normalized to example.com + "example.com": {nvp1}, + }, + }, + { + name: "sub-domains-and-top-domains-do-not-collide", + config: []tailcfg.RawMessage{ + tailcfg.RawMessage(appTwoBytes), + tailcfg.RawMessage(appFiveBytes), + }, + peers: []tailcfg.NodeView{ + nvp1, + nvp3, + }, + want: map[string][]tailcfg.NodeView{ + // The sub.example.com should remain distinct from example.com + "example.com": {nvp1}, + "a.example.com": {nvp3}, + }, + }, } { t.Run(tt.name, func(t *testing.T) { selfNode := &tailcfg.Node{} @@ -117,6 +245,7 @@ func TestPickSplitDNSPeers(t *testing.T) { tailcfg.NodeCapability(AppConnectorsExperimentalAttrName): tt.config, } } + selfNode.Tags = append(selfNode.Tags, tt.selfTags...) selfView := selfNode.View() peers := map[tailcfg.NodeID]tailcfg.NodeView{} for _, p := range tt.peers { @@ -124,10 +253,165 @@ func TestPickSplitDNSPeers(t *testing.T) { } got := PickSplitDNSPeers(func(_ tailcfg.NodeCapability) bool { return true - }, selfView, peers) + }, selfView, peers, tt.isEligibleConnector) + if !reflect.DeepEqual(got, tt.want) { t.Fatalf("got %v, want %v", got, tt.want) } }) } } + +type testNodeBackend struct { + ipnext.NodeBackend + peers []tailcfg.NodeView +} + +func (nb *testNodeBackend) AppendMatchingPeers(base []tailcfg.NodeView, pred func(tailcfg.NodeView) bool) []tailcfg.NodeView { + for _, p := range nb.peers { + if pred(p) { + base = append(base, p) + } + } + return base +} + +func (nb *testNodeBackend) PeerHasPeerAPI(p tailcfg.NodeView) bool { + return true +} + +func TestPickConnector(t *testing.T) { + exampleApp := appctype.Conn25Attr{ + Name: "example", + Connectors: []string{"tag:example"}, + Domains: []string{"example.com"}, + } + + nvWithConnectorSet := func(id tailcfg.NodeID, isConnector bool, tags ...string) tailcfg.NodeView { + return (&tailcfg.Node{ + ID: id, + Tags: tags, + Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(isConnector)}).View(), + }).View() + } + + nv := func(id tailcfg.NodeID, tags ...string) tailcfg.NodeView { + return nvWithConnectorSet(id, true, tags...) + } + + for _, tt := range []struct { + name string + candidates []tailcfg.NodeView + app appctype.Conn25Attr + want []tailcfg.NodeView + }{ + { + name: "empty-everything", + candidates: []tailcfg.NodeView{}, + app: appctype.Conn25Attr{}, + want: nil, + }, + { + name: "empty-candidates", + candidates: []tailcfg.NodeView{}, + app: exampleApp, + want: nil, + }, + { + name: "empty-app", + candidates: []tailcfg.NodeView{nv(1, "tag:example")}, + app: appctype.Conn25Attr{}, + want: nil, + }, + { + name: "one-matches", + candidates: []tailcfg.NodeView{nv(1, "tag:example")}, + app: exampleApp, + want: []tailcfg.NodeView{nv(1, "tag:example")}, + }, + { + name: "invalid-candidate", + candidates: []tailcfg.NodeView{ + {}, + nv(1, "tag:example"), + }, + app: exampleApp, + want: []tailcfg.NodeView{ + nv(1, "tag:example"), + }, + }, + { + name: "no-host-info", + candidates: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 1, + Tags: []string{"tag:example"}, + }).View(), + nv(2, "tag:example"), + }, + app: exampleApp, + want: []tailcfg.NodeView{nv(2, "tag:example")}, + }, + { + name: "not-a-connector", + candidates: []tailcfg.NodeView{nvWithConnectorSet(1, false, "tag:example.com"), nv(2, "tag:example")}, + app: exampleApp, + want: []tailcfg.NodeView{nv(2, "tag:example")}, + }, + { + name: "without-matches", + candidates: []tailcfg.NodeView{nv(1, "tag:woo"), nv(2, "tag:example")}, + app: exampleApp, + want: []tailcfg.NodeView{nv(2, "tag:example")}, + }, + { + name: "multi-tags", + candidates: []tailcfg.NodeView{nv(1, "tag:woo", "tag:hoo"), nv(2, "tag:woo", "tag:example")}, + app: exampleApp, + want: []tailcfg.NodeView{nv(2, "tag:woo", "tag:example")}, + }, + { + name: "multi-matches", + candidates: []tailcfg.NodeView{nv(1, "tag:woo", "tag:hoo"), nv(2, "tag:woo", "tag:example"), nv(3, "tag:example1", "tag:example")}, + app: appctype.Conn25Attr{ + Name: "example2", + Connectors: []string{"tag:example1", "tag:example"}, + Domains: []string{"example.com"}, + }, + want: []tailcfg.NodeView{nv(2, "tag:woo", "tag:example"), nv(3, "tag:example1", "tag:example")}, + }, + { + name: "bit-of-everything", + candidates: []tailcfg.NodeView{ + nv(3, "tag:woo", "tag:hoo"), + {}, + nv(2, "tag:woo", "tag:example"), + nvWithConnectorSet(4, false, "tag:example"), + nv(1, "tag:example1", "tag:example"), + nv(7, "tag:example1", "tag:example"), + nvWithConnectorSet(5, false), + nv(6), + nvWithConnectorSet(8, false, "tag:example"), + nvWithConnectorSet(9, false), + nvWithConnectorSet(10, false), + }, + app: appctype.Conn25Attr{ + Name: "example2", + Connectors: []string{"tag:example1", "tag:example", "tag:example2"}, + Domains: []string{"example.com"}, + }, + want: []tailcfg.NodeView{ + nv(1, "tag:example1", "tag:example"), + nv(2, "tag:woo", "tag:example"), + nv(7, "tag:example1", "tag:example"), + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + got := PickConnector(&testNodeBackend{peers: tt.candidates}, tt.app) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Fatalf("PickConnectors (-want, +got):\n%s", diff) + } + }) + } +} diff --git a/build_docker.sh b/build_docker.sh index 4552f8d8ee0d3..c460668f4ab65 100755 --- a/build_docker.sh +++ b/build_docker.sh @@ -4,7 +4,7 @@ # github.com/tailscale/mkctr. # By default the images will be tagged with the current version and git # hash of this repository as produced by ./cmd/mkversion. -# This is the image build mechanim used to build the official Tailscale +# This is the image build mechanism used to build the official Tailscale # container images. # # If you want to build local images for testing, you can use make, which provides few convenience wrappers around this script. @@ -38,6 +38,7 @@ TARGET="${TARGET:-${DEFAULT_TARGET}}" TAGS="${TAGS:-${DEFAULT_TAGS}}" BASE="${BASE:-${DEFAULT_BASE}}" PLATFORM="${PLATFORM:-}" # default to all platforms +GOARCH="${GOARCH:-arm,arm64,amd64,386,riscv64}" FILES="${FILES:-}" # default to no extra files # OCI annotations that will be added to the image. # https://github.com/opencontainers/image-spec/blob/main/annotations.md @@ -62,6 +63,7 @@ case "$TARGET" in --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --goarch="${GOARCH}" \ --annotations="${ANNOTATIONS}" \ --files="${FILES}" \ /usr/local/bin/containerboot @@ -81,6 +83,7 @@ case "$TARGET" in --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --goarch="${GOARCH}" \ --annotations="${ANNOTATIONS}" \ --files="${FILES}" \ /usr/local/bin/operator @@ -100,6 +103,7 @@ case "$TARGET" in --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --goarch="${GOARCH}" \ --annotations="${ANNOTATIONS}" \ --files="${FILES}" \ /usr/local/bin/k8s-nameserver @@ -119,6 +123,7 @@ case "$TARGET" in --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --goarch="${GOARCH}" \ --annotations="${ANNOTATIONS}" \ --files="${FILES}" \ /usr/local/bin/tsidp @@ -138,6 +143,7 @@ case "$TARGET" in --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --goarch="${GOARCH}" \ --annotations="${ANNOTATIONS}" \ --files="${FILES}" \ /usr/local/bin/k8s-proxy diff --git a/cache_key_test.go b/cache_key_test.go new file mode 100644 index 0000000000000..8600bcd719f04 --- /dev/null +++ b/cache_key_test.go @@ -0,0 +1,57 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tailscaleroot + +import ( + "os" + "os/exec" + "strings" + "testing" + + "tailscale.com/util/cibuild" +) + +// TestTsgoRevInCacheKey verifies that the Tailscale Go toolchain's git +// revision (from go.toolchain.rev) is blended into Go build cache keys. +// Without this, bumping the toolchain to a new commit that doesn't change +// the Go version number would silently reuse stale cached build artifacts. +// +// See https://github.com/tailscale/tailscale/issues/36589. +func TestTsgoRevInCacheKey(t *testing.T) { + goRoot := goEnv(t, "GOROOT") + isTsgo := strings.Contains(goRoot, "/.cache/tsgo/") + if !cibuild.On() && !isTsgo { + t.Skip("skipping; not in CI and not using the Tailscale Go toolchain") + } + + rev := strings.TrimSpace(GoToolchainRev) + if rev == "" { + t.Fatal("go.toolchain.rev is empty") + } + + // Build the small stdlib "errors" package with GODEBUG=gocachehash=1, + // which causes cmd/go to log its cache key computations to stderr. + cmd := exec.Command("go", "build", "errors") + cmd.Env = append(os.Environ(), "GODEBUG=gocachehash=1") + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("go build errors failed: %v\n%s", err, out) + } + + // The cache key output should contain the toolchain rev alongside the + // Go version, e.g.: + // HASH[moduleIndex]: "go1.26.2 dfe2a5fd8ee2e68b08ce5ff259269f50ecadf2f4" + if !strings.Contains(string(out), rev) { + t.Errorf("go.toolchain.rev %q not found in GODEBUG=gocachehash=1 output:\n%s", rev, out) + } +} + +func goEnv(t *testing.T, key string) string { + t.Helper() + out, err := exec.Command("go", "env", key).Output() + if err != nil { + t.Fatalf("go env %s: %v", key, err) + } + return strings.TrimSpace(string(out)) +} diff --git a/client/freedesktop/freedesktop_test.go b/client/freedesktop/freedesktop_test.go index 07a1104f36940..d02d1f67c286a 100644 --- a/client/freedesktop/freedesktop_test.go +++ b/client/freedesktop/freedesktop_test.go @@ -13,12 +13,12 @@ func TestEscape(t *testing.T) { name, input, want string }{ { - name: "no illegal chars", + name: "no-illegal-chars", input: "/home/user", want: "/home/user", }, { - name: "empty string", + name: "empty-string", input: "", want: "\"\"", }, @@ -38,12 +38,12 @@ func TestEscape(t *testing.T) { want: "\"\n\"", }, { - name: "double quote", + name: "double-quote", input: "\"", want: "\"\\\"\"", }, { - name: "single quote", + name: "single-quote", input: "'", want: "\"'\"", }, @@ -53,12 +53,12 @@ func TestEscape(t *testing.T) { want: "\"\\\\\"", }, { - name: "greater than", + name: "greater-than", input: ">", want: "\">\"", }, { - name: "less than", + name: "less-than", input: "<", want: "\"<\"", }, @@ -93,7 +93,7 @@ func TestEscape(t *testing.T) { want: "\"*\"", }, { - name: "question mark", + name: "question-mark", input: "?", want: "\"?\"", }, @@ -103,12 +103,12 @@ func TestEscape(t *testing.T) { want: "\"#\"", }, { - name: "open paren", + name: "open-paren", input: "(", want: "\"(\"", }, { - name: "close paren", + name: "close-paren", input: ")", want: "\")\"", }, @@ -118,17 +118,17 @@ func TestEscape(t *testing.T) { want: "\"\\`\"", }, { - name: "char without escape", + name: "char-without-escape", input: "/home/user\t", want: "\"/home/user\t\"", }, { - name: "char with escape", + name: "char-with-escape", input: "/home/user\\", want: "\"/home/user\\\\\"", }, { - name: "all illegal chars", + name: "all-illegal-chars", input: "/home/user" + needsEscape, want: "\"/home/user \t\n\\\"'\\\\><~|&;\\$*?#()\\`\"", }, diff --git a/client/local/local.go b/client/local/local.go index a7b8b83b10a77..5c75c0487b13b 100644 --- a/client/local/local.go +++ b/client/local/local.go @@ -192,8 +192,8 @@ func (e *AccessDeniedError) Unwrap() error { return e.err } // IsAccessDeniedError reports whether err is or wraps an AccessDeniedError. func IsAccessDeniedError(err error) bool { - var ae *AccessDeniedError - return errors.As(err, &ae) + _, ok := errors.AsType[*AccessDeniedError](err) + return ok } // PreconditionsFailedError is returned when the server responds @@ -210,8 +210,8 @@ func (e *PreconditionsFailedError) Unwrap() error { return e.err } // IsPreconditionsFailedError reports whether err is or wraps an PreconditionsFailedError. func IsPreconditionsFailedError(err error) bool { - var ae *PreconditionsFailedError - return errors.As(err, &ae) + _, ok := errors.AsType[*PreconditionsFailedError](err) + return ok } // bestError returns either err, or if body contains a valid JSON @@ -607,6 +607,24 @@ func (lc *Client) DebugResultJSON(ctx context.Context, action string) (any, erro return x, nil } +// GetDebugResultJSON invokes a debug action and decodes the JSON response +// into a value of type T. It avoids the marshal/unmarshal roundtrip that +// callers of [Client.DebugResultJSON] otherwise need to do to get a typed +// value. +// +// These are development tools and subject to change or removal over time. +func GetDebugResultJSON[T any](ctx context.Context, lc *Client, action string) (T, error) { + var v T + body, err := lc.send(ctx, "POST", "/localapi/v0/debug?action="+url.QueryEscape(action), 200, nil) + if err != nil { + return v, fmt.Errorf("error %w: %s", err, body) + } + if err := json.Unmarshal(body, &v); err != nil { + return v, err + } + return v, nil +} + // QueryOptionalFeatures queries the optional features supported by the Tailscale daemon. func (lc *Client) QueryOptionalFeatures(ctx context.Context) (*apitype.OptionalFeatures, error) { body, err := lc.send(ctx, "POST", "/localapi/v0/debug-optional-features", 200, nil) @@ -972,6 +990,19 @@ func (lc *Client) UserDial(ctx context.Context, network, host string, port uint1 if res.StatusCode != http.StatusSwitchingProtocols { body, _ := io.ReadAll(res.Body) res.Body.Close() + if res.StatusCode == http.StatusOK && res.Header.Get("Dial-Self") == "true" { + // Server told us to dial the address ourselves rather than + // proxying through the daemon. This happens for non-Tailscale + // addresses where the daemon shouldn't dial as root on the + // client's behalf. The server provides the resolved address + // to avoid a TOCTOU race with DNS re-resolution. + addr := res.Header.Get("Dial-Addr") + if addr == "" { + return nil, errors.New("server returned Dial-Self without Dial-Addr") + } + var d net.Dialer + return d.DialContext(ctx, network, addr) + } return nil, fmt.Errorf("unexpected HTTP response: %s, %s", res.Status, body) } // From here on, the underlying net.Conn is ours to use, but there @@ -1009,6 +1040,44 @@ func (lc *Client) CurrentDERPMap(ctx context.Context) (*tailcfg.DERPMap, error) return &derpMap, nil } +// CertDomains returns the list of domains for which the local tailscaled can +// fetch TLS certificates, equivalent to the DNS.CertDomains field of the +// current netmap. The returned list is sorted in ascending order, and is +// empty if no netmap has been received yet. +func (lc *Client) CertDomains(ctx context.Context) ([]string, error) { + body, err := lc.get200(ctx, "/localapi/v0/cert-domains") + if err != nil { + return nil, err + } + return decodeJSON[[]string](body) +} + +// DNSConfig returns the [tailcfg.DNSConfig] from the current netmap. +// It returns an error if no netmap has been received yet. +// It is intended for callers that need fields like ExtraRecords or CertDomains +// without pulling the rest of the netmap. +func (lc *Client) DNSConfig(ctx context.Context) (*tailcfg.DNSConfig, error) { + body, err := lc.get200(ctx, "/localapi/v0/dns-config") + if err != nil { + return nil, err + } + return decodeJSON[*tailcfg.DNSConfig](body) +} + +// PeerByID returns a peer's current full [tailcfg.Node] looked up by its +// [tailcfg.NodeID], in O(1) time on the daemon side. It returns an error +// if no peer with that NodeID is in the current netmap. +// +// It is intended for callers that need the latest state of a single peer +// without fetching the entire netmap. +func (lc *Client) PeerByID(ctx context.Context, id tailcfg.NodeID) (*tailcfg.Node, error) { + body, err := lc.get200(ctx, "/localapi/v0/peer-by-id?id="+strconv.FormatInt(int64(id), 10)) + if err != nil { + return nil, err + } + return decodeJSON[*tailcfg.Node](body) +} + // PingOpts contains options for the ping request. // // The zero value is valid, which means to use defaults. @@ -1071,7 +1140,7 @@ func tailscaledConnectHint() string { // ActiveState=inactive // SubState=dead st := map[string]string{} - for _, line := range strings.Split(string(out), "\n") { + for line := range strings.SplitSeq(string(out), "\n") { if k, v, ok := strings.Cut(line, "="); ok { st[k] = strings.TrimSpace(v) } @@ -1422,3 +1491,13 @@ func (lc *Client) GetAppConnectorRouteInfo(ctx context.Context) (appctype.RouteI } return decodeJSON[appctype.RouteInfo](body) } + +// GetServices returns the Services visible to this node, +// including their names, IP addresses, and ports, keyed by service name. +func (lc *Client) GetServices(ctx context.Context) (map[tailcfg.ServiceName]tailcfg.ServiceDetails, error) { + body, err := lc.get200(ctx, "/localapi/v0/services") + if err != nil { + return nil, err + } + return decodeJSON[map[tailcfg.ServiceName]tailcfg.ServiceDetails](body) +} diff --git a/client/local/local_test.go b/client/local/local_test.go index a5377fbd677a9..58a87b224564b 100644 --- a/client/local/local_test.go +++ b/client/local/local_test.go @@ -61,6 +61,57 @@ func TestWhoIsPeerNotFound(t *testing.T) { } } +func TestUserDialSelf(t *testing.T) { + // Start a real TCP listener that the client should dial directly + // when the server tells it to dial-self. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + go func() { + for { + c, err := ln.Accept() + if err != nil { + return + } + c.Write([]byte("hello")) + c.Close() + } + }() + targetAddr := ln.Addr().(*net.TCPAddr) + + // Mock LocalAPI server that returns Dial-Self response. + nw := nettest.GetNetwork(t) + ts := nettest.NewHTTPServer(nw, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Dial-Self", "true") + w.Header().Set("Dial-Addr", targetAddr.String()) + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + lc := &Client{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nw.Dial(ctx, network, ts.Listener.Addr().String()) + }, + } + + conn, err := lc.UserDial(context.Background(), "tcp", targetAddr.IP.String(), uint16(targetAddr.Port)) + if err != nil { + t.Fatalf("UserDial: %v", err) + } + defer conn.Close() + + buf := make([]byte, 5) + n, err := conn.Read(buf) + if err != nil { + t.Fatalf("Read: %v", err) + } + if got := string(buf[:n]); got != "hello" { + t.Errorf("got %q, want %q", got, "hello") + } +} + func TestDeps(t *testing.T) { deptest.DepChecker{ BadDeps: map[string]string{ diff --git a/client/local/tailnetlock.go b/client/local/tailnetlock.go index 0084cb42e3ab0..5af90eb165102 100644 --- a/client/local/tailnetlock.go +++ b/client/local/tailnetlock.go @@ -28,8 +28,6 @@ func (lc *Client) NetworkLockStatus(ctx context.Context) (*ipnstate.NetworkLockS } // NetworkLockInit initializes the tailnet key authority. -// -// TODO(tom): Plumb through disablement secrets. func (lc *Client) NetworkLockInit(ctx context.Context, keys []tka.Key, disablementValues [][]byte, supportDisablement []byte) (*ipnstate.NetworkLockStatus, error) { var b bytes.Buffer type initRequest struct { diff --git a/client/systray/logo.go b/client/systray/logo.go index 4cd19778dc3a7..334cd7917cf59 100644 --- a/client/systray/logo.go +++ b/client/systray/logo.go @@ -11,6 +11,7 @@ import ( "image" "image/color" "image/png" + "log" "runtime" "sync" "time" @@ -204,12 +205,49 @@ var ( ) var ( - bg = color.NRGBA{0, 0, 0, 255} - fg = color.NRGBA{255, 255, 255, 255} - gray = color.NRGBA{255, 255, 255, 102} - red = color.NRGBA{229, 111, 74, 255} + black = color.NRGBA{0, 0, 0, 255} + white = color.NRGBA{255, 255, 255, 255} + darkGray = color.NRGBA{102, 102, 102, 255} + lightGray = color.NRGBA{153, 153, 153, 255} + red = color.NRGBA{229, 111, 74, 255} + transparent = color.NRGBA{} + + // default values to dark theme + bg = black + fg = white + gray = darkGray ) +// SetTheme sets the color theme of the systray icon. +// +// Supported themes are: +// - dark - white and gray dots over black background +// - dark:nobg - white and grey dots over transparent background +// - light - black and gray dots over white background +// - light:nobg - black and grey dots over transparent background +func SetTheme(theme string) { + switch theme { + case "dark": + bg = black + fg = white + gray = darkGray + case "dark:nobg": + bg = transparent + fg = white + gray = darkGray + case "light": + bg = white + fg = black + gray = lightGray + case "light:nobg": + bg = transparent + fg = black + gray = lightGray + default: + log.Printf("unknown theme: %q", theme) + } +} + // render returns a PNG image of the logo. func (logo tsLogo) render() *bytes.Buffer { const borderUnits = 1 @@ -233,8 +271,8 @@ func (logo tsLogo) renderWithBorder(borderUnits int) *bytes.Buffer { dc.InvertMask() } - for y := 0; y < 3; y++ { - for x := 0; x < 3; x++ { + for y := range 3 { + for x := range 3 { px := (borderUnits + 1 + 3*x) * radius py := (borderUnits + 1 + 3*y) * radius col := fg diff --git a/client/systray/startup-creator.go b/client/systray/startup-creator.go index 369190012ce6c..02a01809945e1 100644 --- a/client/systray/startup-creator.go +++ b/client/systray/startup-creator.go @@ -3,7 +3,6 @@ //go:build cgo || !darwin -// Package systray provides a minimal Tailscale systray application. package systray import ( diff --git a/client/systray/systray.go b/client/systray/systray.go index 65c1bec20a184..d0287e6470b06 100644 --- a/client/systray/systray.go +++ b/client/systray/systray.go @@ -621,11 +621,9 @@ func (menu *Menu) rebuildExitNodeMenu(ctx context.Context) { title += strings.Split(sugg.Name, ".")[0] } menu.exitNodes.AddSeparator() - rm := menu.exitNodes.AddSubMenuItemCheckbox(title, "", false) + active := recommendedIsActive(status, sugg.ID, sugg.Location.CountryCode(), sugg.Location.City()) + rm := menu.exitNodes.AddSubMenuItemCheckbox(title, "", active) setExitNodeOnClick(rm, sugg.ID) - if status.ExitNodeStatus != nil && sugg.ID == status.ExitNodeStatus.ID { - rm.Check() - } } } @@ -647,13 +645,11 @@ func (menu *Menu) rebuildExitNodeMenu(ctx context.Context) { if !ps.Online { name += " (offline)" } - sm := menu.exitNodes.AddSubMenuItemCheckbox(name, "", false) + active := status.ExitNodeStatus != nil && ps.ID == status.ExitNodeStatus.ID + sm := menu.exitNodes.AddSubMenuItemCheckbox(name, "", active) if !ps.Online { sm.Disable() } - if status.ExitNodeStatus != nil && ps.ID == status.ExitNodeStatus.ID { - sm.Check() - } setExitNodeOnClick(sm, ps.ID) } } @@ -743,6 +739,30 @@ func (mc *mvCountry) sortedCities() []*mvCity { return cities } +// recommendedIsActive reports whether the suggested exit node corresponds to +// the currently active exit node in status. +func recommendedIsActive(status *ipnstate.Status, suggID tailcfg.StableNodeID, suggCountry, suggCity string) bool { + if status == nil || status.ExitNodeStatus == nil || status.ExitNodeStatus.ID.IsZero() { + return false + } + if suggID == status.ExitNodeStatus.ID { + return true + } + if suggCountry == "" || suggCity == "" { + return false + } + for _, p := range status.Peer { + if p.ID != status.ExitNodeStatus.ID { + continue + } + if loc := p.Location; loc != nil && loc.CountryCode == suggCountry && loc.City == suggCity { + return true + } + return false + } + return false +} + // countryFlag takes a 2-character ASCII string and returns the corresponding emoji flag. // It returns the empty string on error. func countryFlag(code string) string { diff --git a/client/systray/systray_test.go b/client/systray/systray_test.go new file mode 100644 index 0000000000000..6b8ce8b95e540 --- /dev/null +++ b/client/systray/systray_test.go @@ -0,0 +1,120 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo || !darwin + +package systray + +import ( + "testing" + + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +func TestRecommendedIsActive(t *testing.T) { + t.Parallel() + + const ( + activeID = tailcfg.StableNodeID("active") + suggID = tailcfg.StableNodeID("suggestion") + ) + usNYC := &tailcfg.Location{CountryCode: "US", City: "New York"} + usCHI := &tailcfg.Location{CountryCode: "US", City: "Chicago"} + seSTO := &tailcfg.Location{CountryCode: "SE", City: "Stockholm"} + + statusWith := func(activePeer *ipnstate.PeerStatus) *ipnstate.Status { + s := &ipnstate.Status{ + ExitNodeStatus: &ipnstate.ExitNodeStatus{ID: activeID}, + } + if activePeer != nil { + s.Peer = map[key.NodePublic]*ipnstate.PeerStatus{{}: activePeer} + } + return s + } + + tests := []struct { + name string + status *ipnstate.Status + suggID tailcfg.StableNodeID + suggCountry string + suggCity string + isActive bool + }{ + { + name: "nil_status", + status: nil, + suggID: suggID, + }, + { + name: "no_exit_node", + status: &ipnstate.Status{}, + suggID: suggID, + }, + { + name: "exit_node_id_is_zero", + status: &ipnstate.Status{ExitNodeStatus: &ipnstate.ExitNodeStatus{}}, + suggID: suggID, + }, + { + name: "exact_id_match_short-circuits", + status: statusWith(&ipnstate.PeerStatus{ID: activeID, Location: usCHI}), + suggID: activeID, + suggCountry: "US", + suggCity: "New York", + isActive: true, + }, + { + name: "id_mismatch_but_same_city", + status: statusWith(&ipnstate.PeerStatus{ID: activeID, Location: usNYC}), + suggID: suggID, + suggCountry: "US", + suggCity: "New York", + isActive: true, + }, + { + name: "different_city", + status: statusWith(&ipnstate.PeerStatus{ID: activeID, Location: usCHI}), + suggID: suggID, + suggCountry: "US", + suggCity: "New York", + }, + { + name: "different_country", + status: statusWith(&ipnstate.PeerStatus{ID: activeID, Location: seSTO}), + suggID: suggID, + suggCountry: "US", + suggCity: "New York", + }, + { + name: "id_mismatch_suggestion_has_no_location", + status: statusWith(&ipnstate.PeerStatus{ID: activeID, Location: usNYC}), + suggID: suggID, + }, + { + name: "id_mismatch_active_peer_has_no_location", + status: statusWith(&ipnstate.PeerStatus{ID: activeID}), + suggID: suggID, + suggCountry: "US", + suggCity: "New York", + }, + { + name: "active_peer_not_in_status", + status: statusWith(nil), + suggID: suggID, + suggCountry: "US", + suggCity: "New York", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + isExitNodeActive := recommendedIsActive(tt.status, tt.suggID, tt.suggCountry, tt.suggCity) + if isExitNodeActive != tt.isActive { + t.Errorf("recommendedIsActive; got %v, want %v", isExitNodeActive, tt.isActive) + } + }) + } +} diff --git a/client/tailscale/tailscale_test.go b/client/tailscale/tailscale_test.go index fe2fbe383b679..342a2d7872026 100644 --- a/client/tailscale/tailscale_test.go +++ b/client/tailscale/tailscale_test.go @@ -31,7 +31,7 @@ func TestClientBuildURL(t *testing.T) { want: `http://127.0.0.1:1234/api/v2/tailnet/example%20dot%20com%3Ffoo=bar`, }, { - desc: "url.Values", + desc: "url-Values", elements: []any{"tailnet", "example.com", "acl", url.Values{"details": {"1"}}}, want: `http://127.0.0.1:1234/api/v2/tailnet/example.com/acl?details=1`, }, @@ -71,7 +71,7 @@ func TestClientBuildTailnetURL(t *testing.T) { want: `http://127.0.0.1:1234/api/v2/tailnet/example.com/foo%20bar%3Fbaz=qux`, }, { - desc: "url.Values", + desc: "url-Values", elements: []any{"acl", url.Values{"details": {"1"}}}, want: `http://127.0.0.1:1234/api/v2/tailnet/example.com/acl?details=1`, }, diff --git a/client/web/auth.go b/client/web/auth.go index 4e25b049b30ac..916f24782d55a 100644 --- a/client/web/auth.go +++ b/client/web/auth.go @@ -37,6 +37,7 @@ type browserSession struct { AuthURL string // from tailcfg.WebClientAuthResponse Created time.Time Authenticated bool + PendingAuth bool } // isAuthorized reports true if the given session is authorized @@ -172,12 +173,14 @@ func (s *Server) newSession(ctx context.Context, src *apitype.WhoIsResponse) (*b } session.AuthID = a.ID session.AuthURL = a.URL + session.PendingAuth = true } else { // control does not support check mode, so there is no additional auth we can do. session.Authenticated = true } s.browserSessions.Store(sid, session) + return session, nil } @@ -206,16 +209,24 @@ func (s *Server) awaitUserAuth(ctx context.Context, session *browserSession) err if session.isAuthorized(s.timeNow()) { return nil // already authorized } + a, err := s.waitAuthURL(ctx, session.AuthID, session.SrcNode) if err != nil { - // Clean up the session. Doing this on any error from control - // server to avoid the user getting stuck with a bad session - // cookie. + // Don't delete the session on context cancellation, as this is expected + // when users navigate away or refresh the page. + if errors.Is(err, context.Canceled) { + return err + } + + // Clean up the session for non-cancellation errors from control server + // to avoid the user getting stuck with a bad session cookie. s.browserSessions.Delete(session.ID) return err } + if a.Complete { session.Authenticated = a.Complete + session.PendingAuth = false s.browserSessions.Store(session.ID, session) } return nil diff --git a/client/web/src/api.ts b/client/web/src/api.ts index 246f74ff231c2..ea64742cdd339 100644 --- a/client/web/src/api.ts +++ b/client/web/src/api.ts @@ -123,7 +123,10 @@ export function useAPI() { return apiFetch<{ url?: string }>("/up", "POST", t.data) .then((d) => d.url && window.open(d.url, "_blank")) // "up" login step .then(() => incrementMetric("web_client_node_connect")) - .then(() => mutate("/data")) + .then(() => { + mutate("/data") + mutate("/auth") + }) .catch(handlePostError("Failed to login")) /** @@ -134,9 +137,9 @@ export function useAPI() { // For logout, must increment metric before running api call, // as tailscaled will be unreachable after the call completes. incrementMetric("web_client_node_disconnect") - return apiFetch("/local/v0/logout", "POST").catch( - handlePostError("Failed to logout") - ) + return apiFetch("/local/v0/logout", "POST") + .then(() => mutate("/auth")) + .catch(handlePostError("Failed to logout")) /** * "new-auth-session" handles creating a new check mode session to diff --git a/client/web/src/hooks/auth.ts b/client/web/src/hooks/auth.ts index c3d0cdc877022..c676647ca0b7e 100644 --- a/client/web/src/hooks/auth.ts +++ b/client/web/src/hooks/auth.ts @@ -3,6 +3,7 @@ import { useCallback, useEffect, useState } from "react" import { apiFetch, setSynoToken } from "src/api" +import useSWR from "swr" export type AuthResponse = { serverMode: AuthServerMode @@ -49,33 +50,26 @@ export function hasAnyEditCapabilities(auth: AuthResponse): boolean { * useAuth reports and refreshes Tailscale auth status for the web client. */ export default function useAuth() { - const [data, setData] = useState() - const [loading, setLoading] = useState(true) + const { data, error, mutate } = useSWR("/auth") const [ranSynoAuth, setRanSynoAuth] = useState(false) - const loadAuth = useCallback(() => { - setLoading(true) - return apiFetch("/auth", "GET") - .then((d) => { - setData(d) - if (d.needsSynoAuth) { - fetch("/webman/login.cgi") - .then((r) => r.json()) - .then((a) => { - setSynoToken(a.SynoToken) - setRanSynoAuth(true) - setLoading(false) - }) - } else { - setLoading(false) - } - return d - }) - .catch((error) => { - setLoading(false) - console.error(error) - }) - }, []) + const loading = !data && !error + + // Start Synology auth flow if needed. + useEffect(() => { + if (data?.needsSynoAuth && !ranSynoAuth) { + fetch("/webman/login.cgi") + .then((r) => r.json()) + .then((a) => { + setSynoToken(a.SynoToken) + setRanSynoAuth(true) + mutate() + }) + .catch((error) => { + console.error("Synology auth error:", error) + }) + } + }, [data?.needsSynoAuth, ranSynoAuth, mutate]) const newSession = useCallback(() => { return apiFetch<{ authUrl?: string }>("/auth/session/new", "GET") @@ -86,34 +80,26 @@ export default function useAuth() { } }) .then(() => { - loadAuth() + mutate() }) .catch((error) => { console.error(error) }) - }, [loadAuth]) + }, [mutate]) + // Start regular auth flow. useEffect(() => { - loadAuth().then((d) => { - if (!d) { - return - } - if ( - !d.authorized && - hasAnyEditCapabilities(d) && - // Start auth flow immediately if browser has requested it. - new URLSearchParams(window.location.search).get("check") === "now" - ) { - newSession() - } - }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) + const needsAuth = + data && + !loading && + !data.authorized && + hasAnyEditCapabilities(data) && + new URLSearchParams(window.location.search).get("check") === "now" - useEffect(() => { - loadAuth() // Refresh auth state after syno auth runs - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [ranSynoAuth]) + if (needsAuth) { + newSession() + } + }, [data, loading, newSession]) return { data, diff --git a/client/web/web.go b/client/web/web.go index f8a9e7c1769a2..95259ef1a9039 100644 --- a/client/web/web.go +++ b/client/web/web.go @@ -35,8 +35,10 @@ import ( "tailscale.com/net/netutil" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" + "tailscale.com/tsweb" "tailscale.com/types/logger" "tailscale.com/types/views" + "tailscale.com/util/ctxkey" "tailscale.com/util/httpm" "tailscale.com/util/syspolicy/policyclient" "tailscale.com/version" @@ -527,45 +529,40 @@ func (s *Server) serveLoginAPI(w http.ResponseWriter, r *http.Request) { } } -type apiHandler[data any] struct { - s *Server - w http.ResponseWriter - r *http.Request - - // permissionCheck allows for defining whether a requesting peer's - // capabilities grant them access to make the given data update. - // If permissionCheck reports false, the request fails as unauthorized. - permissionCheck func(data data, peer peerCapabilities) bool -} - -// newHandler constructs a new api handler which restricts the given request -// to the specified permission check. If the permission check fails for -// the peer associated with the request, an unauthorized error is returned -// to the client. -func newHandler[data any](s *Server, w http.ResponseWriter, r *http.Request, permissionCheck func(data data, peer peerCapabilities) bool) *apiHandler[data] { - return &apiHandler[data]{ - s: s, - w: w, - r: r, - permissionCheck: permissionCheck, +// handleJSON manages decoding the request's body JSON as data and passing it +// on to the provided handler function. +func handleJSON[data any](h func(ctx context.Context, data data) error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + var body data + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := h(r.Context(), body); err != nil { + if httpErr, ok := errors.AsType[tsweb.HTTPError](err); ok { + tsweb.WriteHTTPError(w, r, httpErr) + } else { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + w.WriteHeader(http.StatusOK) } } -// alwaysAllowed can be passed as the permissionCheck argument to newHandler -// for requests that are always allowed to complete regardless of a peer's -// capabilities. -func alwaysAllowed[data any](_ data, _ peerCapabilities) bool { return true } +var contextKeyPeer = ctxkey.New("peer-capabilities", peerCapabilities{}) -func (a *apiHandler[data]) getPeer() (peerCapabilities, error) { +func (s *Server) setPeer(r *http.Request) (*http.Request, error) { // TODO(tailscale/corp#16695,sonia): We also call StatusWithoutPeers and // WhoIs when originally checking for a session from authorizeRequest. // Would be nice if we could pipe those through to here so we don't end // up having to re-call them to grab the peer capabilities. - status, err := a.s.lc.StatusWithoutPeers(a.r.Context()) + status, err := s.lc.StatusWithoutPeers(r.Context()) if err != nil { return nil, err } - whois, err := a.s.lc.WhoIs(a.r.Context(), a.r.RemoteAddr) + whois, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) if err != nil { return nil, err } @@ -573,56 +570,11 @@ func (a *apiHandler[data]) getPeer() (peerCapabilities, error) { if err != nil { return nil, err } - return peer, nil -} - -type noBodyData any // empty type, for use from serveAPI for endpoints with empty body - -// handle runs the given handler if the source peer satisfies the -// constraints for running this request. -// -// handle is expected for use when `data` type is empty, or set to -// `noBodyData` in practice. For requests that expect JSON body data -// to be attached, use handleJSON instead. -func (a *apiHandler[data]) handle(h http.HandlerFunc) { - peer, err := a.getPeer() - if err != nil { - http.Error(a.w, err.Error(), http.StatusInternalServerError) - return - } - var body data // not used - if !a.permissionCheck(body, peer) { - http.Error(a.w, "not allowed", http.StatusUnauthorized) - return - } - h(a.w, a.r) + return r.WithContext(contextKeyPeer.WithValue(r.Context(), peer)), nil } -// handleJSON manages decoding the request's body JSON and passing -// it on to the provided function if the source peer satisfies the -// constraints for running this request. -func (a *apiHandler[data]) handleJSON(h func(ctx context.Context, data data) error) { - defer a.r.Body.Close() - var body data - if err := json.NewDecoder(a.r.Body).Decode(&body); err != nil { - http.Error(a.w, err.Error(), http.StatusInternalServerError) - return - } - peer, err := a.getPeer() - if err != nil { - http.Error(a.w, err.Error(), http.StatusInternalServerError) - return - } - if !a.permissionCheck(body, peer) { - http.Error(a.w, "not allowed", http.StatusUnauthorized) - return - } - - if err := h(a.r.Context(), body); err != nil { - http.Error(a.w, err.Error(), http.StatusInternalServerError) - return - } - a.w.WriteHeader(http.StatusOK) +func (s *Server) getPeer(ctx context.Context) peerCapabilities { + return contextKeyPeer.Value(ctx) } // serveAPI serves requests for the web client api. @@ -637,67 +589,44 @@ func (s *Server) serveAPI(w http.ResponseWriter, r *http.Request) { } } + var err error + r, err = s.setPeer(r) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + path := strings.TrimPrefix(r.URL.Path, "/api") switch { case path == "/data" && r.Method == httpm.GET: - newHandler[noBodyData](s, w, r, alwaysAllowed). - handle(s.serveGetNodeData) + s.serveGetNodeData(w, r) return case path == "/exit-nodes" && r.Method == httpm.GET: - newHandler[noBodyData](s, w, r, alwaysAllowed). - handle(s.serveGetExitNodes) + s.serveGetExitNodes(w, r) return case path == "/routes" && r.Method == httpm.POST: - peerAllowed := func(d postRoutesRequest, p peerCapabilities) bool { - if d.SetExitNode && !p.canEdit(capFeatureExitNodes) { - return false - } else if d.SetRoutes && !p.canEdit(capFeatureSubnets) { - return false - } - return true - } - newHandler[postRoutesRequest](s, w, r, peerAllowed). - handleJSON(s.servePostRoutes) + handleJSON[postRoutesRequest](s.servePostRoutes)(w, r) return case path == "/device-details-click" && r.Method == httpm.POST: - newHandler[noBodyData](s, w, r, alwaysAllowed). - handle(s.serveDeviceDetailsClick) + s.serveDeviceDetailsClick(w, r) return case path == "/local/v0/logout" && r.Method == httpm.POST: - peerAllowed := func(_ noBodyData, peer peerCapabilities) bool { - return peer.canEdit(capFeatureAccount) - } - newHandler[noBodyData](s, w, r, peerAllowed). - handle(s.proxyRequestToLocalAPI) + s.proxyRequestToLocalAPI(w, r) return case path == "/local/v0/prefs" && r.Method == httpm.PATCH: - peerAllowed := func(data maskedPrefs, peer peerCapabilities) bool { - if data.RunSSHSet && !peer.canEdit(capFeatureSSH) { - return false - } - return true - } - newHandler[maskedPrefs](s, w, r, peerAllowed). - handleJSON(s.serveUpdatePrefs) + handleJSON[maskedPrefs](s.serveUpdatePrefs)(w, r) return case path == "/local/v0/update/check" && r.Method == httpm.GET: - newHandler[noBodyData](s, w, r, alwaysAllowed). - handle(s.proxyRequestToLocalAPI) + s.proxyRequestToLocalAPI(w, r) return case path == "/local/v0/update/check" && r.Method == httpm.POST: - peerAllowed := func(_ noBodyData, peer peerCapabilities) bool { - return peer.canEdit(capFeatureAccount) - } - newHandler[noBodyData](s, w, r, peerAllowed). - handle(s.proxyRequestToLocalAPI) + s.proxyRequestToLocalAPI(w, r) return case path == "/local/v0/update/progress" && r.Method == httpm.POST: - newHandler[noBodyData](s, w, r, alwaysAllowed). - handle(s.proxyRequestToLocalAPI) + s.proxyRequestToLocalAPI(w, r) return case path == "/local/v0/upload-client-metrics" && r.Method == httpm.POST: - newHandler[noBodyData](s, w, r, alwaysAllowed). - handle(s.proxyRequestToLocalAPI) + s.proxyRequestToLocalAPI(w, r) return } http.Error(w, "invalid endpoint", http.StatusNotFound) @@ -771,6 +700,19 @@ func (s *Server) serveAPIAuth(w http.ResponseWriter, r *http.Request) { } } + // We might have a session for which we haven't awaited the result yet. + // This can happen when the AuthURL opens in the same browser tab instead + // of a new one due to browser settings. + // (See https://github.com/tailscale/tailscale/issues/11905) + // We therefore set a PendingAuth flag when creating a new session, check + // it here and call awaitUserAuth if we find it to be true. Once the auth + // wait completes, awaitUserAuth will set PendingAuth to false. + if sErr == nil && session.PendingAuth == true { + if err := s.awaitUserAuth(r.Context(), session); err != nil { + sErr = err + } + } + switch { case sErr != nil && errors.Is(sErr, errNotUsingTailscale): s.lc.IncrementCounter(r.Context(), "web_client_viewing_local", 1) @@ -1109,6 +1051,11 @@ type maskedPrefs struct { } func (s *Server) serveUpdatePrefs(ctx context.Context, prefs maskedPrefs) error { + peer := s.getPeer(ctx) + if prefs.RunSSHSet && !peer.canEdit(capFeatureSSH) { + return tsweb.Error(http.StatusUnauthorized, "RunSSHSet not allowed", nil) + } + _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ RunSSHSet: prefs.RunSSHSet, Prefs: ipn.Prefs{ @@ -1127,6 +1074,17 @@ type postRoutesRequest struct { } func (s *Server) servePostRoutes(ctx context.Context, data postRoutesRequest) error { + if !data.SetExitNode && !data.SetRoutes { + return tsweb.Error(http.StatusBadRequest, "must specify SetExitNode or SetRoutes", nil) + } + peer := s.getPeer(ctx) + if data.SetExitNode && !peer.canEdit(capFeatureExitNodes) { + return tsweb.Error(http.StatusUnauthorized, "SetExitNode not allowed", nil) + } + if data.SetRoutes && !peer.canEdit(capFeatureSubnets) { + return tsweb.Error(http.StatusUnauthorized, "SetRoutes not allowed", nil) + } + prefs, err := s.lc.GetPrefs(ctx) if err != nil { return err @@ -1140,13 +1098,14 @@ func (s *Server) servePostRoutes(ctx context.Context, data postRoutesRequest) er } currNonExitRoutes = append(currNonExitRoutes, r.String()) } - // Set non-edited fields to their current values. - if data.SetExitNode { - data.AdvertiseRoutes = currNonExitRoutes - } else if data.SetRoutes { + // For each group of fields not being set, preserve the current prefs. + if !data.SetExitNode { data.AdvertiseExitNode = currAdvertisingExitNode data.UseExitNode = prefs.ExitNodeID } + if !data.SetRoutes { + data.AdvertiseRoutes = currNonExitRoutes + } // Calculate routes. routesStr := strings.Join(data.AdvertiseRoutes, ",") @@ -1323,6 +1282,19 @@ func (s *Server) proxyRequestToLocalAPI(w http.ResponseWriter, r *http.Request) return } + switch path { + case "/v0/logout": + if !s.getPeer(r.Context()).canEdit(capFeatureAccount) { + http.Error(w, "not allowed", http.StatusUnauthorized) + return + } + case "/v0/update/check": + if r.Method == httpm.POST && !s.getPeer(r.Context()).canEdit(capFeatureAccount) { + http.Error(w, "not allowed", http.StatusUnauthorized) + return + } + } + localAPIURL := "http://" + apitype.LocalAPIHost + "/localapi" + path req, err := http.NewRequestWithContext(r.Context(), r.Method, localAPIURL, r.Body) if err != nil { diff --git a/client/web/web_test.go b/client/web/web_test.go index 6b9a51002b33b..51b6a8ac58781 100644 --- a/client/web/web_test.go +++ b/client/web/web_test.go @@ -41,37 +41,37 @@ func TestQnapAuthnURL(t *testing.T) { want string }{ { - name: "localhost http", + name: "localhost-http", in: "http://localhost:8088/", want: "http://localhost:8088/cgi-bin/authLogin.cgi?qtoken=token", }, { - name: "localhost https", + name: "localhost-https", in: "https://localhost:5000/", want: "https://localhost:5000/cgi-bin/authLogin.cgi?qtoken=token", }, { - name: "IP http", + name: "IP-http", in: "http://10.1.20.4:80/", want: "http://10.1.20.4:80/cgi-bin/authLogin.cgi?qtoken=token", }, { - name: "IP6 https", + name: "IP6-https", in: "https://[ff7d:0:1:2::1]/", want: "https://[ff7d:0:1:2::1]/cgi-bin/authLogin.cgi?qtoken=token", }, { - name: "hostname https", + name: "hostname-https", in: "https://qnap.example.com/", want: "https://qnap.example.com/cgi-bin/authLogin.cgi?qtoken=token", }, { - name: "invalid URL", + name: "invalid-URL", in: "This is not a URL, it is a really really really really really really really really really really really really long string to exercise the URL truncation code in the error path.", want: "http://localhost/cgi-bin/authLogin.cgi?qtoken=token", }, { - name: "err != nil", + name: "err-not-nil", in: "http://192.168.0.%31/", want: "http://localhost/cgi-bin/authLogin.cgi?qtoken=token", }, @@ -191,7 +191,7 @@ func TestServeAPI(t *testing.T) { reqBody: "{\"setExitNode\":true}", tests: []requestTest{{ remoteIP: remoteIPWithNoCapabilities, - wantResponse: "not allowed", + wantResponse: "SetExitNode not allowed", wantStatus: http.StatusUnauthorized, }, { remoteIP: remoteIPWithAllCapabilities, @@ -204,7 +204,7 @@ func TestServeAPI(t *testing.T) { reqContentType: "application/json", tests: []requestTest{{ remoteIP: remoteIPWithNoCapabilities, - wantResponse: "not allowed", + wantResponse: "RunSSHSet not allowed", wantStatus: http.StatusUnauthorized, }, { remoteIP: remoteIPWithAllCapabilities, @@ -582,12 +582,23 @@ func TestServeAuth(t *testing.T) { successCookie := "ts-cookie-success" s.browserSessions.Store(successCookie, &browserSession{ - ID: successCookie, - SrcNode: remoteNode.Node.ID, - SrcUser: user.ID, - Created: oneHourAgo, - AuthID: testAuthPathSuccess, - AuthURL: *testControlURL + testAuthPathSuccess, + ID: successCookie, + SrcNode: remoteNode.Node.ID, + SrcUser: user.ID, + Created: oneHourAgo, + AuthID: testAuthPathSuccess, + AuthURL: *testControlURL + testAuthPathSuccess, + PendingAuth: true, + }) + successCookie2 := "ts-cookie-success-2" + s.browserSessions.Store(successCookie2, &browserSession{ + ID: successCookie2, + SrcNode: remoteNode.Node.ID, + SrcUser: user.ID, + Created: oneHourAgo, + AuthID: testAuthPathSuccess, + AuthURL: *testControlURL + testAuthPathSuccess, + PendingAuth: true, }) failureCookie := "ts-cookie-failure" s.browserSessions.Store(failureCookie, &browserSession{ @@ -642,14 +653,15 @@ func TestServeAuth(t *testing.T) { AuthID: testAuthPath, AuthURL: *testControlURL + testAuthPath, Authenticated: false, + PendingAuth: true, }, }, { - name: "query-existing-incomplete-session", - path: "/api/auth", + name: "existing-session-used", + path: "/api/auth/session/new", // should not create new session cookie: successCookie, wantStatus: http.StatusOK, - wantResp: &authResponse{ViewerIdentity: vi, ServerMode: ManageServerMode}, + wantResp: &newSessionAuthResponse{AuthURL: *testControlURL + testAuthPathSuccess}, wantSession: &browserSession{ ID: successCookie, SrcNode: remoteNode.Node.ID, @@ -658,14 +670,15 @@ func TestServeAuth(t *testing.T) { AuthID: testAuthPathSuccess, AuthURL: *testControlURL + testAuthPathSuccess, Authenticated: false, + PendingAuth: true, }, }, { - name: "existing-session-used", - path: "/api/auth/session/new", // should not create new session + name: "transition-to-successful-session-via-api-auth-session-wait", + path: "/api/auth/session/wait", cookie: successCookie, wantStatus: http.StatusOK, - wantResp: &newSessionAuthResponse{AuthURL: *testControlURL + testAuthPathSuccess}, + wantResp: nil, wantSession: &browserSession{ ID: successCookie, SrcNode: remoteNode.Node.ID, @@ -673,17 +686,17 @@ func TestServeAuth(t *testing.T) { Created: oneHourAgo, AuthID: testAuthPathSuccess, AuthURL: *testControlURL + testAuthPathSuccess, - Authenticated: false, + Authenticated: true, }, }, { - name: "transition-to-successful-session", - path: "/api/auth/session/wait", - cookie: successCookie, + name: "transition-to-successful-session-via-api-auth", + path: "/api/auth", + cookie: successCookie2, wantStatus: http.StatusOK, - wantResp: nil, + wantResp: &authResponse{Authorized: true, ViewerIdentity: vi, ServerMode: ManageServerMode}, wantSession: &browserSession{ - ID: successCookie, + ID: successCookie2, SrcNode: remoteNode.Node.ID, SrcUser: user.ID, Created: oneHourAgo, @@ -731,6 +744,7 @@ func TestServeAuth(t *testing.T) { AuthID: testAuthPath, AuthURL: *testControlURL + testAuthPath, Authenticated: false, + PendingAuth: true, }, }, { @@ -748,6 +762,7 @@ func TestServeAuth(t *testing.T) { AuthID: testAuthPath, AuthURL: *testControlURL + testAuthPath, Authenticated: false, + PendingAuth: true, }, }, { @@ -1462,7 +1477,9 @@ func mockLocalAPI(t *testing.T, whoIs map[string]*apitype.WhoIsResponse, self fu http.Error(w, "invalid JSON body", http.StatusBadRequest) return } - metricCapture(metricNames[0].Name) + if metricCapture != nil && len(metricNames) > 0 { + metricCapture(metricNames[0].Name) + } writeJSON(w, struct{}{}) return case "/localapi/v0/logout": @@ -1501,47 +1518,47 @@ func TestCSRFProtect(t *testing.T) { wantError bool }{ { - name: "GET requests with no header are allowed", + name: "GET-no-header-allowed", // GET requests with no header are allowed method: "GET", }, { - name: "POST requests with same-origin are allowed", + name: "POST-same-origin-allowed", method: "POST", secFetchSite: "same-origin", }, { - name: "POST requests with cross-site are not allowed", + name: "POST-cross-site-rejected", method: "POST", secFetchSite: "cross-site", wantError: true, }, { - name: "POST requests with unknown sec-fetch-site values are not allowed", + name: "POST-unknown-sec-fetch-site-rejected", method: "POST", secFetchSite: "new-unknown-value", wantError: true, }, { - name: "POST requests with none are not allowed", + name: "POST-sec-fetch-none-rejected", method: "POST", secFetchSite: "none", wantError: true, }, { - name: "POST requests with no sec-fetch-site header but matching host and origin are allowed", + name: "POST-no-sec-fetch-site-matching-host-origin", // no sec-fetch-site header but matching host and origin are allowed method: "POST", host: "example.com", origin: "https://example.com", }, { - name: "POST requests with no sec-fetch-site and non-matching host and origin are not allowed", + name: "POST-no-sec-fetch-site-mismatched-host-origin-rejected", method: "POST", host: "example.com", origin: "https://example.net", wantError: true, }, { - name: "POST requests with no sec-fetch-site and and origin that matches the override are allowed", + name: "POST-no-sec-fetch-site-origin-override-allowed", method: "POST", originOverride: "example.net", host: "internal.example.foo", // Host can be changed by reverse proxies @@ -1587,3 +1604,149 @@ func TestCSRFProtect(t *testing.T) { }) } } + +func TestServePostRoutes(t *testing.T) { + existingExitNodeID := tailcfg.StableNodeID("existing-exit-node") + existingRoute := netip.MustParsePrefix("192.168.1.0/24") + + existingPrefs := &ipn.Prefs{ + ExitNodeID: existingExitNodeID, + AdvertiseRoutes: []netip.Prefix{existingRoute}, + } + + tests := []struct { + name string + data postRoutesRequest + peerCaps peerCapabilities + wantErr bool + wantEditPrefs bool // whether EditPrefs (PATCH /prefs) should be called + wantExitNodeID tailcfg.StableNodeID + wantRoutes []netip.Prefix + }{ + { + name: "empty-request", + data: postRoutesRequest{}, + peerCaps: peerCapabilities{capFeatureExitNodes: true, capFeatureSubnets: true}, + wantErr: true, + wantEditPrefs: false, + }, + { + name: "SetExitNode-only", + data: postRoutesRequest{ + SetExitNode: true, + UseExitNode: "new-exit-node", + }, + peerCaps: peerCapabilities{capFeatureExitNodes: true, capFeatureSubnets: true}, + wantEditPrefs: true, + wantExitNodeID: "new-exit-node", + wantRoutes: []netip.Prefix{existingRoute}, + }, + { + name: "SetExitNode-not-allowed", + data: postRoutesRequest{ + SetExitNode: true, + UseExitNode: "new-exit-node", + }, + peerCaps: peerCapabilities{capFeatureSubnets: true}, + wantErr: true, + }, + { + name: "SetRoutes-only", + data: postRoutesRequest{ + SetRoutes: true, + AdvertiseRoutes: []string{"10.0.0.0/8"}, + }, + peerCaps: peerCapabilities{capFeatureExitNodes: true, capFeatureSubnets: true}, + wantEditPrefs: true, + wantExitNodeID: existingExitNodeID, + wantRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + }, + { + name: "SetRoutes-not-allowed", + data: postRoutesRequest{ + SetRoutes: true, + AdvertiseRoutes: []string{"10.0.0.0/8"}, + }, + peerCaps: peerCapabilities{capFeatureExitNodes: true}, + wantErr: true, + }, + { + name: "SetExitNode-and-SetRoutes", + data: postRoutesRequest{ + SetExitNode: true, + SetRoutes: true, + UseExitNode: "new-exit-node", + AdvertiseRoutes: []string{"10.0.0.0/8"}, + }, + peerCaps: peerCapabilities{capFeatureExitNodes: true, capFeatureSubnets: true}, + wantEditPrefs: true, + wantExitNodeID: "new-exit-node", + wantRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var gotPrefs *ipn.MaskedPrefs + + lal := memnet.Listen("local-tailscaled.sock:80") + defer lal.Close() + + localapi := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/localapi/v0/prefs" { + t.Errorf("unexpected localapi call to %q", r.URL.Path) + http.Error(w, "unexpected localapi call", http.StatusInternalServerError) + return + } + switch r.Method { + case httpm.GET: + writeJSON(w, existingPrefs) + case httpm.PATCH: + var mp ipn.MaskedPrefs + if err := json.NewDecoder(r.Body).Decode(&mp); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + gotPrefs = &mp + writeJSON(w, gotPrefs.Prefs) + default: + t.Errorf("unexpected method %q on /prefs", r.Method) + http.Error(w, "unexpected method", http.StatusMethodNotAllowed) + } + })} + defer localapi.Close() + go localapi.Serve(lal) + + s := &Server{ + mode: ManageServerMode, + lc: &local.Client{Dial: lal.Dial}, + } + + ctx := contextKeyPeer.WithValue(t.Context(), tt.peerCaps) + err := s.servePostRoutes(ctx, tt.data) + + if tt.wantErr { + if err == nil { + t.Error("wanted error, got nil") + } + if gotPrefs != nil { + t.Error("EditPrefs should not have been called on error") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if gotPrefs == nil { + t.Fatal("expected EditPrefs to be called") + } + if diff := cmp.Diff(tt.wantExitNodeID, gotPrefs.ExitNodeID); diff != "" { + t.Errorf("ExitNodeID mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(tt.wantRoutes, gotPrefs.AdvertiseRoutes, cmp.Comparer(func(a, b netip.Prefix) bool { return a.Compare(b) == 0 })); diff != "" { + t.Errorf("AdvertiseRoutes mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/clientupdate/clientupdate.go b/clientupdate/clientupdate.go index d52241483812a..6d034b342d1cf 100644 --- a/clientupdate/clientupdate.go +++ b/clientupdate/clientupdate.go @@ -1292,6 +1292,6 @@ func requireRoot() error { } func isExitError(err error) bool { - var exitErr *exec.ExitError - return errors.As(err, &exitErr) + _, ok := errors.AsType[*exec.ExitError](err) + return ok } diff --git a/clientupdate/clientupdate_test.go b/clientupdate/clientupdate_test.go index 13fc8f08a6a2e..8095151c8169c 100644 --- a/clientupdate/clientupdate_test.go +++ b/clientupdate/clientupdate_test.go @@ -148,27 +148,27 @@ func TestUpdateYUMRepoTrack(t *testing.T) { wantErr bool }{ { - desc: "same track", + desc: "same-track", before: YUMRepos[StableTrack], track: StableTrack, after: YUMRepos[StableTrack], }, { - desc: "change track", + desc: "change-track", before: YUMRepos[StableTrack], track: UnstableTrack, after: YUMRepos[UnstableTrack], rewrote: true, }, { - desc: "change track RC", + desc: "change-track-RC", before: YUMRepos[StableTrack], track: ReleaseCandidateTrack, after: YUMRepos[ReleaseCandidateTrack], rewrote: true, }, { - desc: "non-tailscale repo file", + desc: "non-tailscale-repo-file", before: YUMRepos["FakeRepo"], track: StableTrack, wantErr: true, @@ -215,7 +215,7 @@ func TestParseAlpinePackageVersion(t *testing.T) { wantErr bool }{ { - desc: "valid version", + desc: "valid-version", out: ` tailscale-1.44.2-r0 description: The easiest, most secure way to use WireGuard and 2FA @@ -229,7 +229,7 @@ tailscale-1.44.2-r0 installed size: want: "1.44.2", }, { - desc: "wrong package output", + desc: "wrong-package-output", out: ` busybox-1.36.1-r0 description: Size optimized toolbox of many common UNIX utilities @@ -243,7 +243,7 @@ busybox-1.36.1-r0 installed size: wantErr: true, }, { - desc: "missing version", + desc: "missing-version", out: ` tailscale description: The easiest, most secure way to use WireGuard and 2FA @@ -257,12 +257,12 @@ tailscale installed size: wantErr: true, }, { - desc: "empty output", + desc: "empty-output", out: "", wantErr: true, }, { - desc: "multiple versions", + desc: "multiple-versions", out: ` tailscale-1.54.1-r0 description: The easiest, most secure way to use WireGuard and 2FA @@ -322,14 +322,14 @@ func TestCheckOutdatedAlpineRepo(t *testing.T) { track string }{ { - name: "Up to date", + name: "up-to-date", fileContent: "https://dl-cdn.alpinelinux.org/alpine/v3.20/main", latestHTTPVersion: "1.95.3", latestApkVersion: "1.95.3", track: "unstable", }, { - name: "Behind unstable", + name: "behind-unstable", fileContent: "https://dl-cdn.alpinelinux.org/alpine/v3.20/main", latestHTTPVersion: "1.95.4", latestApkVersion: "1.95.3", @@ -339,7 +339,7 @@ func TestCheckOutdatedAlpineRepo(t *testing.T) { track: "unstable", }, { - name: "Behind stable", + name: "behind-stable", fileContent: "https://dl-cdn.alpinelinux.org/alpine/v2.40/main", latestHTTPVersion: "1.94.3", latestApkVersion: "1.92.1", @@ -349,7 +349,7 @@ func TestCheckOutdatedAlpineRepo(t *testing.T) { track: "stable", }, { - name: "Nothing in dist file", + name: "nothing-in-dist-file", fileContent: "", latestHTTPVersion: "1.94.3", latestApkVersion: "1.92.1", @@ -451,7 +451,7 @@ func TestSynoArch(t *testing.T) { synoinfoConfPath := filepath.Join(t.TempDir(), "synoinfo.conf") if err := os.WriteFile( synoinfoConfPath, - []byte(fmt.Sprintf("unique=%q\n", tt.synoinfoUnique)), + fmt.Appendf(nil, "unique=%q\n", tt.synoinfoUnique), 0600, ); err != nil { t.Fatal(err) @@ -505,14 +505,14 @@ unique=synology_88f6281_213air want: "88f6281", }, { - desc: "missing unique", + desc: "missing-unique", content: ` company_title="Synology" `, wantErr: true, }, { - desc: "empty unique", + desc: "empty-unique", content: ` company_title="Synology" unique= @@ -520,7 +520,7 @@ unique= wantErr: true, }, { - desc: "empty unique double-quoted", + desc: "empty-unique-double-quoted", content: ` company_title="Synology" unique="" @@ -528,7 +528,7 @@ unique="" wantErr: true, }, { - desc: "empty unique single-quoted", + desc: "empty-unique-single-quoted", content: ` company_title="Synology" unique='' @@ -536,7 +536,7 @@ unique='' wantErr: true, }, { - desc: "malformed unique", + desc: "malformed-unique", content: ` company_title="Synology" unique="synology_88f6281" @@ -544,12 +544,12 @@ unique="synology_88f6281" wantErr: true, }, { - desc: "empty file", + desc: "empty-file", content: ``, wantErr: true, }, { - desc: "empty lines and comments", + desc: "empty-lines-and-comments", content: ` # In a file named synoinfo? Shocking! @@ -613,7 +613,7 @@ func TestUnpackLinuxTarball(t *testing.T) { }, }, { - desc: "don't touch unrelated files", + desc: "skip-unrelated-files", // don't touch unrelated files before: map[string]string{ "tailscale": "v1", "tailscaled": "v1", @@ -645,7 +645,7 @@ func TestUnpackLinuxTarball(t *testing.T) { }, }, { - desc: "ignore extra tarball files", + desc: "ignore-extra-tarball-files", before: map[string]string{ "tailscale": "v1", "tailscaled": "v1", @@ -661,7 +661,7 @@ func TestUnpackLinuxTarball(t *testing.T) { }, }, { - desc: "tarball missing tailscaled", + desc: "tarball-missing-tailscaled", before: map[string]string{ "tailscale": "v1", "tailscaled": "v1", @@ -677,7 +677,7 @@ func TestUnpackLinuxTarball(t *testing.T) { wantErr: true, }, { - desc: "duplicate tailscale binary", + desc: "duplicate-tailscale-binary", before: map[string]string{ "tailscale": "v1", "tailscaled": "v1", @@ -696,7 +696,7 @@ func TestUnpackLinuxTarball(t *testing.T) { wantErr: true, }, { - desc: "empty archive", + desc: "empty-archive", before: map[string]string{ "tailscale": "v1", "tailscaled": "v1", @@ -952,17 +952,18 @@ func TestCleanupOldDownloads(t *testing.T) { func TestParseUnraidPluginVersion(t *testing.T) { tests := []struct { + name string plgPath string wantVer string wantErr string }{ - {plgPath: "testdata/tailscale-1.52.0.plg", wantVer: "1.52.0"}, - {plgPath: "testdata/tailscale-1.54.0.plg", wantVer: "1.54.0"}, - {plgPath: "testdata/tailscale-nover.plg", wantErr: "version not found in plg file"}, - {plgPath: "testdata/tailscale-nover-path-mentioned.plg", wantErr: "version not found in plg file"}, + {name: "v1_52_0", plgPath: "testdata/tailscale-1.52.0.plg", wantVer: "1.52.0"}, + {name: "v1_54_0", plgPath: "testdata/tailscale-1.54.0.plg", wantVer: "1.54.0"}, + {name: "nover", plgPath: "testdata/tailscale-nover.plg", wantErr: "version not found in plg file"}, + {name: "nover-path-mentioned", plgPath: "testdata/tailscale-nover-path-mentioned.plg", wantErr: "version not found in plg file"}, } for _, tt := range tests { - t.Run(tt.plgPath, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { got, err := parseUnraidPluginVersion(tt.plgPath) if got != tt.wantVer { t.Errorf("got version: %q, want %q", got, tt.wantVer) @@ -992,7 +993,7 @@ func TestConfirm(t *testing.T) { want bool }{ { - desc: "on latest stable", + desc: "on-latest-stable", fromTrack: StableTrack, toTrack: StableTrack, fromVer: "1.66.0", @@ -1000,7 +1001,7 @@ func TestConfirm(t *testing.T) { want: false, }, { - desc: "stable upgrade", + desc: "stable-upgrade", fromTrack: StableTrack, toTrack: StableTrack, fromVer: "1.66.0", @@ -1008,7 +1009,7 @@ func TestConfirm(t *testing.T) { want: true, }, { - desc: "unstable upgrade", + desc: "unstable-upgrade", fromTrack: UnstableTrack, toTrack: UnstableTrack, fromVer: "1.67.1", @@ -1016,7 +1017,7 @@ func TestConfirm(t *testing.T) { want: true, }, { - desc: "from stable to unstable", + desc: "from-stable-to-unstable", fromTrack: StableTrack, toTrack: UnstableTrack, fromVer: "1.66.0", @@ -1024,7 +1025,7 @@ func TestConfirm(t *testing.T) { want: true, }, { - desc: "from unstable to stable", + desc: "from-unstable-to-stable", fromTrack: UnstableTrack, toTrack: StableTrack, fromVer: "1.67.1", @@ -1032,7 +1033,7 @@ func TestConfirm(t *testing.T) { want: true, }, { - desc: "confirm callback rejects", + desc: "confirm-callback-rejects", fromTrack: StableTrack, toTrack: StableTrack, fromVer: "1.66.0", @@ -1043,7 +1044,7 @@ func TestConfirm(t *testing.T) { want: false, }, { - desc: "confirm callback allows", + desc: "confirm-callback-allows", fromTrack: StableTrack, toTrack: StableTrack, fromVer: "1.66.0", diff --git a/clientupdate/clientupdate_windows.go b/clientupdate/clientupdate_windows.go index 70a3c509121ea..50b77c38b4e5a 100644 --- a/clientupdate/clientupdate_windows.go +++ b/clientupdate/clientupdate_windows.go @@ -38,12 +38,12 @@ const ( updaterPrefix = "tailscale-updater" ) -func makeSelfCopy() (origPathExe, tmpPathExe string, err error) { - selfExe, err := os.Executable() +func makeCmdTailscaleCopy() (origPathExe, tmpPathExe string, err error) { + srcExe, err := findCmdTailscale() if err != nil { return "", "", err } - f, err := os.Open(selfExe) + f, err := os.Open(srcExe) if err != nil { return "", "", err } @@ -59,7 +59,25 @@ func makeSelfCopy() (origPathExe, tmpPathExe string, err error) { f2.Close() return "", "", err } - return selfExe, f2.Name(), f2.Close() + return srcExe, f2.Name(), f2.Close() +} + +// findCmdTailscale returns the path to the binary that should be copied for the update +// re-execution. The copy is re-executed with "update" as a subcommand, so it must be +// a binary that handles "update" (ie tailscale.exe, not tailscaled.exe) +func findCmdTailscale() (string, error) { + selfExe, err := os.Executable() + if err != nil { + return "", err + } + if strings.EqualFold(filepath.Base(selfExe), "tailscale.exe") { + return selfExe, nil + } + ts := filepath.Join(filepath.Dir(selfExe), "tailscale.exe") + if _, err := os.Stat(ts); err != nil { + return "", fmt.Errorf("cannot find tailscale.exe alongside %s: %w", selfExe, err) + } + return ts, nil } func markTempFileWindows(name string) error { @@ -159,14 +177,14 @@ you can run the command prompt as Administrator one of these ways: up.Logf("making tailscale.exe copy to switch to...") up.cleanupOldDownloads(filepath.Join(os.TempDir(), updaterPrefix+"-*.exe")) - _, selfCopy, err := makeSelfCopy() + _, cmdTailscaleCopy, err := makeCmdTailscaleCopy() if err != nil { return err } - defer os.Remove(selfCopy) + defer os.Remove(cmdTailscaleCopy) up.Logf("running tailscale.exe copy for final install...") - cmd := exec.Command(selfCopy, "update") + cmd := exec.Command(cmdTailscaleCopy, "update") cmd.Env = append(os.Environ(), winMSIEnv+"="+msiTarget, winVersionEnv+"="+ver) cmd.Stdout = up.Stderr cmd.Stderr = up.Stderr diff --git a/clientupdate/distsign/distsign_test.go b/clientupdate/distsign/distsign_test.go index 0d454771fc9a4..1380078859f3a 100644 --- a/clientupdate/distsign/distsign_test.go +++ b/clientupdate/distsign/distsign_test.go @@ -30,7 +30,7 @@ func TestDownload(t *testing.T) { wantErr bool }{ { - desc: "missing file", + desc: "missing-file", before: func(*testing.T) {}, src: "hello", wantErr: true, @@ -44,7 +44,7 @@ func TestDownload(t *testing.T) { want: []byte("world"), }, { - desc: "no signature", + desc: "no-signature", before: func(*testing.T) { srv.add("hello", []byte("world")) }, @@ -52,7 +52,7 @@ func TestDownload(t *testing.T) { wantErr: true, }, { - desc: "bad signature", + desc: "bad-signature", before: func(*testing.T) { srv.add("hello", []byte("world")) srv.add("hello.sig", []byte("potato")) @@ -61,7 +61,7 @@ func TestDownload(t *testing.T) { wantErr: true, }, { - desc: "signed with untrusted key", + desc: "signed-untrusted-key", before: func(t *testing.T) { srv.add("hello", []byte("world")) srv.add("hello.sig", newSigningKeyPair(t).sign([]byte("world"))) @@ -70,7 +70,7 @@ func TestDownload(t *testing.T) { wantErr: true, }, { - desc: "signed with root key", + desc: "signed-with-root-key", before: func(t *testing.T) { srv.add("hello", []byte("world")) srv.add("hello.sig", ed25519.Sign(srv.roots[0].k, []byte("world"))) @@ -79,7 +79,7 @@ func TestDownload(t *testing.T) { wantErr: true, }, { - desc: "bad signing key signature", + desc: "bad-signing-key-signature", before: func(t *testing.T) { srv.add("distsign.pub.sig", []byte("potato")) srv.addSigned("hello", []byte("world")) @@ -130,7 +130,7 @@ func TestValidateLocalBinary(t *testing.T) { wantErr bool }{ { - desc: "missing file", + desc: "missing-file", before: func(*testing.T) {}, src: "hello", wantErr: true, @@ -143,7 +143,7 @@ func TestValidateLocalBinary(t *testing.T) { src: "hello", }, { - desc: "contents changed", + desc: "contents-changed", before: func(*testing.T) { srv.addSigned("hello", []byte("new world")) }, @@ -151,7 +151,7 @@ func TestValidateLocalBinary(t *testing.T) { wantErr: true, }, { - desc: "no signature", + desc: "no-signature", before: func(*testing.T) { srv.add("hello", []byte("world")) }, @@ -159,7 +159,7 @@ func TestValidateLocalBinary(t *testing.T) { wantErr: true, }, { - desc: "bad signature", + desc: "bad-signature", before: func(*testing.T) { srv.add("hello", []byte("world")) srv.add("hello.sig", []byte("potato")) @@ -168,7 +168,7 @@ func TestValidateLocalBinary(t *testing.T) { wantErr: true, }, { - desc: "signed with untrusted key", + desc: "signed-untrusted-key", before: func(t *testing.T) { srv.add("hello", []byte("world")) srv.add("hello.sig", newSigningKeyPair(t).sign([]byte("world"))) @@ -177,7 +177,7 @@ func TestValidateLocalBinary(t *testing.T) { wantErr: true, }, { - desc: "signed with root key", + desc: "signed-with-root-key", before: func(t *testing.T) { srv.add("hello", []byte("world")) srv.add("hello.sig", ed25519.Sign(srv.roots[0].k, []byte("world"))) @@ -186,7 +186,7 @@ func TestValidateLocalBinary(t *testing.T) { wantErr: true, }, { - desc: "bad signing key signature", + desc: "bad-signing-key-signature", before: func(t *testing.T) { srv.add("distsign.pub.sig", []byte("potato")) srv.addSigned("hello", []byte("world")) @@ -341,7 +341,7 @@ func TestParseRootKey(t *testing.T) { wantErr: true, }, { - desc: "invalid PEM tag", + desc: "invalid-PEM-tag", generate: func() ([]byte, []byte, error) { priv, pub, err := GenerateRootKey() priv = bytes.Replace(priv, []byte("ROOT "), nil, -1) @@ -350,7 +350,7 @@ func TestParseRootKey(t *testing.T) { wantErr: true, }, { - desc: "not PEM", + desc: "not-PEM", generate: func() ([]byte, []byte, error) { return []byte("s3cr3t"), nil, nil }, wantErr: true, }, @@ -399,7 +399,7 @@ func TestParseSigningKey(t *testing.T) { wantErr: true, }, { - desc: "invalid PEM tag", + desc: "invalid-PEM-tag", generate: func() ([]byte, []byte, error) { priv, pub, err := GenerateSigningKey() priv = bytes.Replace(priv, []byte("SIGNING "), nil, -1) @@ -408,7 +408,7 @@ func TestParseSigningKey(t *testing.T) { wantErr: true, }, { - desc: "not PEM", + desc: "not-PEM", generate: func() ([]byte, []byte, error) { return []byte("s3cr3t"), nil, nil }, wantErr: true, }, diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index a3f0684faa589..8b4cacf7a8849 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -129,6 +129,12 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { } continue } + // Named types with basic underlying types (map/slice) that + // have their own Clone method should use it directly. + if methodResultType(ft, "Clone") != nil { + writef("dst.%s = src.%s.Clone()", fname, fname) + continue + } } switch ft := ft.Underlying().(type) { case *types.Slice: @@ -137,27 +143,9 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("if src.%s != nil {", fname) writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname) writef("for i := range dst.%s {", fname) - if ptr, isPtr := ft.Elem().(*types.Pointer); isPtr { - writef("if src.%s[i] == nil { dst.%s[i] = nil } else {", fname, fname) - if codegen.ContainsPointers(ptr.Elem()) { - if _, isIface := ptr.Elem().Underlying().(*types.Interface); isIface { - it.Import("", "tailscale.com/types/ptr") - writef("\tdst.%s[i] = ptr.To((*src.%s[i]).Clone())", fname, fname) - } else { - writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) - } - } else { - it.Import("", "tailscale.com/types/ptr") - writef("\tdst.%s[i] = ptr.To(*src.%s[i])", fname, fname) - } - writef("}") - } else if ft.Elem().String() == "encoding/json.RawMessage" { - writef("\tdst.%s[i] = append(src.%s[i][:0:0], src.%s[i]...)", fname, fname, fname) - } else if _, isIface := ft.Elem().Underlying().(*types.Interface); isIface { - writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) - } else { - writef("\tdst.%s[i] = *src.%s[i].Clone()", fname, fname) - } + writeSliceElemClone(writef, ft.Elem(), + fmt.Sprintf("src.%s[i]", fname), + fmt.Sprintf("dst.%s[i]", fname)) writef("}") writef("}") } else { @@ -170,12 +158,11 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("dst.%s = src.%s.Clone()", fname, fname) continue } - it.Import("", "tailscale.com/types/ptr") writef("if dst.%s != nil {", fname) if _, isIface := base.Underlying().(*types.Interface); isIface && hasPtrs { - writef("\tdst.%s = ptr.To((*src.%s).Clone())", fname, fname) + writef("\tdst.%s = new((*src.%s).Clone())", fname, fname) } else if !hasPtrs { - writef("\tdst.%s = ptr.To(*src.%s)", fname, fname) + writef("\tdst.%s = new(*src.%s)", fname, fname) } else { writef("\t" + `panic("TODO pointers in pointers")`) } @@ -186,11 +173,28 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { n := it.QualifiedName(sliceType.Elem()) writef("if dst.%s != nil {", fname) writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem)) - writef("\tfor k := range src.%s {", fname) - // use zero-length slice instead of nil to ensure - // the key is always copied. - writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname) - writef("\t}") + if codegen.ContainsPointers(sliceType.Elem()) { + writef("\tfor k, sv := range src.%s {", fname) + writef("\t\tif sv == nil {") + writef("\t\t\tdst.%s[k] = nil", fname) + writef("\t\t\tcontinue") + writef("\t\t}") + writef("\t\tdst.%s[k] = make([]%s, len(sv))", fname, n) + writef("\t\tfor i := range sv {") + innerWritef := func(format string, args ...any) { + writef("\t\t"+format, args...) + } + writeSliceElemClone(innerWritef, sliceType.Elem(), + "sv[i]", fmt.Sprintf("dst.%s[k][i]", fname)) + writef("\t\t}") + writef("\t}") + } else { + writef("\tfor k := range src.%s {", fname) + // use zero-length slice instead of nil to ensure + // the key is always copied. + writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname) + writef("\t}") + } writef("}") } else if codegen.IsViewType(elem) || !codegen.ContainsPointers(elem) { // If the map values are view types (which are @@ -239,6 +243,31 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { buf.Write(codegen.AssertStructUnchanged(t, name, typeParams, "Clone", it)) } +// writeSliceElemClone generates code to deep-clone a single slice element +// from srcExpr to dstExpr. It handles pointer, json.RawMessage, interface, +// and named struct element types. +func writeSliceElemClone(writef func(string, ...any), elemType types.Type, srcExpr, dstExpr string) { + if ptr, isPtr := elemType.(*types.Pointer); isPtr { + writef("if %s == nil { %s = nil } else {", srcExpr, dstExpr) + if codegen.ContainsPointers(ptr.Elem()) { + if _, isIface := ptr.Elem().Underlying().(*types.Interface); isIface { + writef("\t%s = new((*%s).Clone())", dstExpr, srcExpr) + } else { + writef("\t%s = %s.Clone()", dstExpr, srcExpr) + } + } else { + writef("\t%s = new(*%s)", dstExpr, srcExpr) + } + writef("}") + } else if elemType.String() == "encoding/json.RawMessage" { + writef("%s = append(%s[:0:0], %s...)", dstExpr, srcExpr, srcExpr) + } else if _, isIface := elemType.Underlying().(*types.Interface); isIface { + writef("%s = %s.Clone()", dstExpr, srcExpr) + } else { + writef("%s = *%s.Clone()", dstExpr, srcExpr) + } +} + // hasBasicUnderlying reports true when typ.Underlying() is a slice or a map. func hasBasicUnderlying(typ types.Type) bool { switch typ.Underlying().(type) { @@ -293,14 +322,12 @@ func writeMapValueClone(params mapValueCloneParams) { writef("if %s == nil { %s = nil } else {", params.SrcExpr, params.DstExpr) if base := elem.Elem().Underlying(); codegen.ContainsPointers(base) { if _, isIface := base.(*types.Interface); isIface { - params.It.Import("", "tailscale.com/types/ptr") - writef("\t%s = ptr.To((*%s).Clone())", params.DstExpr, params.SrcExpr) + writef("\t%s = new((*%s).Clone())", params.DstExpr, params.SrcExpr) } else { writef("\t%s = %s.Clone()", params.DstExpr, params.SrcExpr) } } else { - params.It.Import("", "tailscale.com/types/ptr") - writef("\t%s = ptr.To(*%s)", params.DstExpr, params.SrcExpr) + writef("\t%s = new(*%s)", params.DstExpr, params.SrcExpr) } writef("}") diff --git a/cmd/cloner/cloner_test.go b/cmd/cloner/cloner_test.go index b06f5c4fa5610..f8beb4a88b952 100644 --- a/cmd/cloner/cloner_test.go +++ b/cmd/cloner/cloner_test.go @@ -7,6 +7,7 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" "tailscale.com/cmd/cloner/clonerex" ) @@ -154,6 +155,74 @@ func TestMapWithPointers(t *testing.T) { } } +func TestNamedMapContainer(t *testing.T) { + orig := &clonerex.NamedMapContainer{ + Attrs: clonerex.NamedMap{ + "str": "hello", + "num": int64(42), + "bool": true, + }, + } + + cloned := orig.Clone() + if !reflect.DeepEqual(orig, cloned) { + t.Errorf("Clone() = %v, want %v", cloned, orig) + } + + // Mutate the cloned map to verify no aliasing. + cloned.Attrs["str"] = "modified" + if orig.Attrs["str"] == "modified" { + t.Errorf("Clone() aliased memory in Attrs: original was modified") + } + + // Verify nil handling. + nilContainer := &clonerex.NamedMapContainer{} + nilClone := nilContainer.Clone() + if !reflect.DeepEqual(nilContainer, nilClone) { + t.Errorf("Clone() of nil Attrs = %v, want %v", nilClone, nilContainer) + } +} + +func TestMapSlicePointerContainer(t *testing.T) { + num := 42 + orig := &clonerex.MapSlicePointerContainer{ + Routes: map[string][]*clonerex.SliceContainer{ + "route1": { + {Slice: []*int{&num}}, + {Slice: []*int{&num, &num}}, + }, + "route2": { + {Slice: []*int{&num}}, + }, + }, + } + + cloned := orig.Clone() + if !reflect.DeepEqual(orig, cloned) { + t.Errorf("Clone() = %v, want %v", cloned, orig) + } + + // Mutate cloned.Routes pointer values + *cloned.Routes["route1"][0].Slice[0] = 999 + if *orig.Routes["route1"][0].Slice[0] == 999 { + t.Errorf("Clone() aliased memory in Routes: original was modified") + } +} + +func TestMapSlicePointerContainerNilValue(t *testing.T) { + num := 7 + orig := &clonerex.MapSlicePointerContainer{ + Routes: map[string][]*clonerex.SliceContainer{ + "nil-value": nil, + "non-nil": {{Slice: []*int{&num}}}, + }, + } + cloned := orig.Clone() + if diff := cmp.Diff(orig.Routes, cloned.Routes); diff != "" { + t.Errorf("Clone() Routes mismatch (-orig +cloned):\n%s", diff) + } +} + func TestDeeplyNestedMap(t *testing.T) { num := 123 orig := &clonerex.DeeplyNestedMap{ diff --git a/cmd/cloner/clonerex/clonerex.go b/cmd/cloner/clonerex/clonerex.go index 1007d0c6b646d..41626d3ae8b45 100644 --- a/cmd/cloner/clonerex/clonerex.go +++ b/cmd/cloner/clonerex/clonerex.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type SliceContainer,InterfaceContainer,MapWithPointers,DeeplyNestedMap +//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type SliceContainer,InterfaceContainer,MapWithPointers,DeeplyNestedMap,NamedMapContainer,MapSlicePointerContainer // Package clonerex is an example package for the cloner tool. package clonerex @@ -39,6 +39,34 @@ type MapWithPointers struct { CloneInterface map[string]Cloneable } +// NamedMap is a named map type with its own Clone method. +// This tests that the cloner uses the type's Clone method +// rather than trying to descend into the map's value type. +type NamedMap map[string]any + +func (m NamedMap) Clone() NamedMap { + if m == nil { + return nil + } + m2 := make(NamedMap, len(m)) + for k, v := range m { + m2[k] = v + } + return m2 +} + +// NamedMapContainer has a field whose type is a named map with a Clone method. +type NamedMapContainer struct { + Attrs NamedMap +} + +// MapSlicePointerContainer has a map whose values are slices of pointers. +// This tests that the cloner deep-clones the pointer elements in the slice, +// not just the slice itself (which would leave aliased pointers). +type MapSlicePointerContainer struct { + Routes map[string][]*SliceContainer +} + // DeeplyNestedMap tests arbitrary depth of map nesting (3+ levels) type DeeplyNestedMap struct { ThreeLevels map[string]map[string]map[string]int diff --git a/cmd/cloner/clonerex/clonerex_clone.go b/cmd/cloner/clonerex/clonerex_clone.go index 5c161239fc992..9a4413177bb47 100644 --- a/cmd/cloner/clonerex/clonerex_clone.go +++ b/cmd/cloner/clonerex/clonerex_clone.go @@ -7,8 +7,6 @@ package clonerex import ( "maps" - - "tailscale.com/types/ptr" ) // Clone makes a deep copy of SliceContainer. @@ -25,7 +23,7 @@ func (src *SliceContainer) Clone() *SliceContainer { if src.Slice[i] == nil { dst.Slice[i] = nil } else { - dst.Slice[i] = ptr.To(*src.Slice[i]) + dst.Slice[i] = new(*src.Slice[i]) } } } @@ -70,7 +68,7 @@ func (src *MapWithPointers) Clone() *MapWithPointers { if v == nil { dst.Nested[k] = nil } else { - dst.Nested[k] = ptr.To(*v) + dst.Nested[k] = new(*v) } } } @@ -161,9 +159,59 @@ var _DeeplyNestedMapCloneNeedsRegeneration = DeeplyNestedMap(struct { FourLevels map[string]map[string]map[string]map[string]*SliceContainer }{}) +// Clone makes a deep copy of NamedMapContainer. +// The result aliases no memory with the original. +func (src *NamedMapContainer) Clone() *NamedMapContainer { + if src == nil { + return nil + } + dst := new(NamedMapContainer) + *dst = *src + dst.Attrs = src.Attrs.Clone() + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _NamedMapContainerCloneNeedsRegeneration = NamedMapContainer(struct { + Attrs NamedMap +}{}) + +// Clone makes a deep copy of MapSlicePointerContainer. +// The result aliases no memory with the original. +func (src *MapSlicePointerContainer) Clone() *MapSlicePointerContainer { + if src == nil { + return nil + } + dst := new(MapSlicePointerContainer) + *dst = *src + if dst.Routes != nil { + dst.Routes = map[string][]*SliceContainer{} + for k, sv := range src.Routes { + if sv == nil { + dst.Routes[k] = nil + continue + } + dst.Routes[k] = make([]*SliceContainer, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.Routes[k][i] = nil + } else { + dst.Routes[k][i] = sv[i].Clone() + } + } + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _MapSlicePointerContainerCloneNeedsRegeneration = MapSlicePointerContainer(struct { + Routes map[string][]*SliceContainer +}{}) + // Clone duplicates src into dst and reports whether it succeeded. // To succeed, must be of types <*T, *T> or <*T, **T>, -// where T is one of SliceContainer,InterfaceContainer,MapWithPointers,DeeplyNestedMap. +// where T is one of SliceContainer,InterfaceContainer,MapWithPointers,DeeplyNestedMap,NamedMapContainer,MapSlicePointerContainer. func Clone(dst, src any) bool { switch src := src.(type) { case *SliceContainer: @@ -202,6 +250,24 @@ func Clone(dst, src any) bool { *dst = src.Clone() return true } + case *NamedMapContainer: + switch dst := dst.(type) { + case *NamedMapContainer: + *dst = *src.Clone() + return true + case **NamedMapContainer: + *dst = src.Clone() + return true + } + case *MapSlicePointerContainer: + switch dst := dst.(type) { + case *MapSlicePointerContainer: + *dst = *src.Clone() + return true + case **MapSlicePointerContainer: + *dst = src.Clone() + return true + } } return false } diff --git a/cmd/containerboot/egressservices.go b/cmd/containerboot/egressservices.go index e60d65c047f95..79fb77c8569ed 100644 --- a/cmd/containerboot/egressservices.go +++ b/cmd/containerboot/egressservices.go @@ -22,11 +22,12 @@ import ( "time" "github.com/fsnotify/fsnotify" + "tailscale.com/client/local" - "tailscale.com/ipn" "tailscale.com/kube/egressservices" "tailscale.com/kube/kubeclient" "tailscale.com/kube/kubetypes" + "tailscale.com/types/netmap" "tailscale.com/util/httpm" "tailscale.com/util/linuxfw" "tailscale.com/util/mak" @@ -54,7 +55,7 @@ type egressProxy struct { tsClient *local.Client // never nil - netmapChan chan ipn.Notify // chan to receive netmap updates on + netmapChan chan *netmap.NetworkMap // chan to receive netmap updates on podIPv4 string // never empty string, currently only IPv4 is supported @@ -86,7 +87,7 @@ type httpClient interface { // - the mounted egress config has changed // - the proxy's tailnet IP addresses have changed // - tailnet IPs have changed for any backend targets specified by tailnet FQDN -func (ep *egressProxy) run(ctx context.Context, n ipn.Notify, opts egressProxyRunOpts) error { +func (ep *egressProxy) run(ctx context.Context, nm *netmap.NetworkMap, opts egressProxyRunOpts) error { ep.configure(opts) var tickChan <-chan time.Time var eventChan <-chan fsnotify.Event @@ -105,7 +106,7 @@ func (ep *egressProxy) run(ctx context.Context, n ipn.Notify, opts egressProxyRu eventChan = w.Events } - if err := ep.sync(ctx, n); err != nil { + if err := ep.sync(ctx, nm); err != nil { return err } for { @@ -116,14 +117,14 @@ func (ep *egressProxy) run(ctx context.Context, n ipn.Notify, opts egressProxyRu log.Printf("periodic sync, ensuring firewall config is up to date...") case <-eventChan: log.Printf("config file change detected, ensuring firewall config is up to date...") - case n = <-ep.netmapChan: - shouldResync := ep.shouldResync(n) + case nm = <-ep.netmapChan: + shouldResync := ep.shouldResync(nm) if !shouldResync { continue } log.Printf("netmap change detected, ensuring firewall config is up to date...") } - if err := ep.sync(ctx, n); err != nil { + if err := ep.sync(ctx, nm); err != nil { return fmt.Errorf("error syncing egress service config: %w", err) } } @@ -135,7 +136,7 @@ type egressProxyRunOpts struct { kc kubeclient.Client tsClient *local.Client stateSecret string - netmapChan chan ipn.Notify + netmapChan chan *netmap.NetworkMap podIPv4 string tailnetAddrs []netip.Prefix } @@ -164,7 +165,7 @@ func (ep *egressProxy) configure(opts egressProxyRunOpts) { // any firewall rules need to be updated. Currently using status in state Secret as a reference for what is the current // firewall configuration is good enough because - the status is keyed by the Pod IP - we crash the Pod on errors such // as failed firewall update -func (ep *egressProxy) sync(ctx context.Context, n ipn.Notify) error { +func (ep *egressProxy) sync(ctx context.Context, nm *netmap.NetworkMap) error { cfgs, err := ep.getConfigs() if err != nil { return fmt.Errorf("error retrieving egress service configs: %w", err) @@ -173,12 +174,12 @@ func (ep *egressProxy) sync(ctx context.Context, n ipn.Notify) error { if err != nil { return fmt.Errorf("error retrieving current egress proxy status: %w", err) } - newStatus, err := ep.syncEgressConfigs(cfgs, status, n) + newStatus, err := ep.syncEgressConfigs(cfgs, status, nm) if err != nil { return fmt.Errorf("error syncing egress service configs: %w", err) } if !servicesStatusIsEqual(newStatus, status) { - if err := ep.setStatus(ctx, newStatus, n); err != nil { + if err := ep.setStatus(ctx, newStatus, nm); err != nil { return fmt.Errorf("error setting egress proxy status: %w", err) } } @@ -187,14 +188,14 @@ func (ep *egressProxy) sync(ctx context.Context, n ipn.Notify) error { // addrsHaveChanged returns true if the provided netmap update contains tailnet address change for this proxy node. // Netmap must not be nil. -func (ep *egressProxy) addrsHaveChanged(n ipn.Notify) bool { - return !reflect.DeepEqual(ep.tailnetAddrs, n.NetMap.SelfNode.Addresses()) +func (ep *egressProxy) addrsHaveChanged(nm *netmap.NetworkMap) bool { + return !reflect.DeepEqual(ep.tailnetAddrs, nm.SelfNode.Addresses()) } // syncEgressConfigs adds and deletes firewall rules to match the desired // configuration. It uses the provided status to determine what is currently // applied and updates the status after a successful sync. -func (ep *egressProxy) syncEgressConfigs(cfgs *egressservices.Configs, status *egressservices.Status, n ipn.Notify) (*egressservices.Status, error) { +func (ep *egressProxy) syncEgressConfigs(cfgs egressservices.Configs, status *egressservices.Status, nm *netmap.NetworkMap) (*egressservices.Status, error) { if !(wantsServicesConfigured(cfgs) || hasServicesConfigured(status)) { return nil, nil } @@ -212,8 +213,8 @@ func (ep *egressProxy) syncEgressConfigs(cfgs *egressservices.Configs, status *e // Add new services, update rules for any that have changed. rulesPerSvcToAdd := make(map[string][]rule, 0) rulesPerSvcToDelete := make(map[string][]rule, 0) - for svcName, cfg := range *cfgs { - tailnetTargetIPs, err := ep.tailnetTargetIPsForSvc(cfg, n) + for svcName, cfg := range cfgs { + tailnetTargetIPs, err := ep.tailnetTargetIPsForSvc(cfg, nm) if err != nil { return nil, fmt.Errorf("error determining tailnet target IPs: %w", err) } @@ -228,12 +229,12 @@ func (ep *egressProxy) syncEgressConfigs(cfgs *egressservices.Configs, status *e if len(rulesToDelete) != 0 { mak.Set(&rulesPerSvcToDelete, svcName, rulesToDelete) } - if len(rulesToAdd) != 0 || ep.addrsHaveChanged(n) { + if len(rulesToAdd) != 0 || ep.addrsHaveChanged(nm) { // For each tailnet target, set up SNAT from the local tailnet device address of the matching // family. for _, t := range tailnetTargetIPs { var local netip.Addr - for _, pfx := range n.NetMap.SelfNode.Addresses().All() { + for _, pfx := range nm.SelfNode.Addresses().All() { if !pfx.IsSingleIP() { continue } @@ -249,6 +250,9 @@ func (ep *egressProxy) syncEgressConfigs(cfgs *egressservices.Configs, status *e if err := ep.nfr.EnsureSNATForDst(local, t); err != nil { return nil, fmt.Errorf("error setting up SNAT rule: %w", err) } + if err := ep.nfr.ClampMSSToPMTU(tailscaleTunInterface, t); err != nil { + return nil, fmt.Errorf("error clamping MSS to PMTU: %w", err) + } } } // Update the status. Status will be written back to the state Secret by the caller. @@ -352,7 +356,7 @@ func updatesForCfg(svcName string, cfg egressservices.Config, status *egressserv // deleteUnneccessaryServices ensure that any services found on status, but not // present in config are deleted. -func (ep *egressProxy) deleteUnnecessaryServices(cfgs *egressservices.Configs, status *egressservices.Status) error { +func (ep *egressProxy) deleteUnnecessaryServices(cfgs egressservices.Configs, status *egressservices.Status) error { if !hasServicesConfigured(status) { return nil } @@ -367,7 +371,7 @@ func (ep *egressProxy) deleteUnnecessaryServices(cfgs *egressservices.Configs, s } for svcName, svc := range status.Services { - if _, ok := (*cfgs)[svcName]; !ok { + if _, ok := cfgs[svcName]; !ok { log.Printf("service %s is no longer required, deleting", svcName) if err := ensureServiceDeleted(svcName, svc, ep.nfr); err != nil { return fmt.Errorf("error deleting service %s: %w", svcName, err) @@ -379,7 +383,7 @@ func (ep *egressProxy) deleteUnnecessaryServices(cfgs *egressservices.Configs, s } // getConfigs gets the mounted egress service configuration. -func (ep *egressProxy) getConfigs() (*egressservices.Configs, error) { +func (ep *egressProxy) getConfigs() (egressservices.Configs, error) { svcsCfg := filepath.Join(ep.cfgPath, egressservices.KeyEgressServices) j, err := os.ReadFile(svcsCfg) if os.IsNotExist(err) { @@ -391,7 +395,7 @@ func (ep *egressProxy) getConfigs() (*egressservices.Configs, error) { if len(j) == 0 || string(j) == "" { return nil, nil } - cfg := &egressservices.Configs{} + cfg := egressservices.Configs{} if err := json.Unmarshal(j, &cfg); err != nil { return nil, err } @@ -423,7 +427,7 @@ func (ep *egressProxy) getStatus(ctx context.Context) (*egressservices.Status, e // setStatus writes egress proxy's currently configured firewall to the state // Secret and updates proxy's tailnet addresses. -func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Status, n ipn.Notify) error { +func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Status, nm *netmap.NetworkMap) error { // Pod IP is used to determine if a stored status applies to THIS proxy Pod. if status == nil { status = &egressservices.Status{} @@ -446,7 +450,7 @@ func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Sta if err := ep.kc.JSONPatchResource(ctx, ep.stateSecret, kubeclient.TypeSecrets, []kubeclient.JSONPatch{patch}); err != nil { return fmt.Errorf("error patching state Secret: %w", err) } - ep.tailnetAddrs = n.NetMap.SelfNode.Addresses().AsSlice() + ep.tailnetAddrs = nm.SelfNode.Addresses().AsSlice() return nil } @@ -456,7 +460,7 @@ func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Sta // FQDN, resolve the FQDN and return the resolved IPs. It checks if the // netfilter runner supports IPv6 NAT and skips any IPv6 addresses if it // doesn't. -func (ep *egressProxy) tailnetTargetIPsForSvc(svc egressservices.Config, n ipn.Notify) (addrs []netip.Addr, err error) { +func (ep *egressProxy) tailnetTargetIPsForSvc(svc egressservices.Config, nm *netmap.NetworkMap) (addrs []netip.Addr, err error) { if svc.TailnetTarget.IP != "" { addr, err := netip.ParseAddr(svc.TailnetTarget.IP) if err != nil { @@ -472,11 +476,11 @@ func (ep *egressProxy) tailnetTargetIPsForSvc(svc egressservices.Config, n ipn.N if svc.TailnetTarget.FQDN == "" { return nil, errors.New("unexpected egress service config- neither tailnet target IP nor FQDN is set") } - if n.NetMap == nil { + if nm == nil { log.Printf("netmap is not available, unable to determine backend addresses for %s", svc.TailnetTarget.FQDN) return addrs, nil } - egressAddrs, err := resolveTailnetFQDN(n.NetMap, svc.TailnetTarget.FQDN) + egressAddrs, err := resolveTailnetFQDN(nm, svc.TailnetTarget.FQDN) if err != nil { log.Printf("error fetching backend addresses for %q: %v", svc.TailnetTarget.FQDN, err) return addrs, nil @@ -502,22 +506,22 @@ func (ep *egressProxy) tailnetTargetIPsForSvc(svc egressservices.Config, n ipn.N // shouldResync parses netmap update and returns true if the update contains // changes for which the egress proxy's firewall should be reconfigured. -func (ep *egressProxy) shouldResync(n ipn.Notify) bool { - if n.NetMap == nil { +func (ep *egressProxy) shouldResync(nm *netmap.NetworkMap) bool { + if nm == nil { return false } // If proxy's tailnet addresses have changed, resync. - if !reflect.DeepEqual(n.NetMap.SelfNode.Addresses().AsSlice(), ep.tailnetAddrs) { + if !reflect.DeepEqual(nm.SelfNode.Addresses().AsSlice(), ep.tailnetAddrs) { log.Printf("node addresses have changed, trigger egress config resync") - ep.tailnetAddrs = n.NetMap.SelfNode.Addresses().AsSlice() + ep.tailnetAddrs = nm.SelfNode.Addresses().AsSlice() return true } // If the IPs for any of the egress services configured via FQDN have // changed, resync. for fqdn, ips := range ep.targetFQDNs { - for _, nn := range n.NetMap.Peers { + for _, nn := range nm.Peers { if equalFQDNs(nn.Name(), fqdn) { if !reflect.DeepEqual(ips, nn.Addresses().AsSlice()) { log.Printf("backend addresses for egress target %q have changed old IPs %v, new IPs %v trigger egress config resync", nn.Name(), ips, nn.Addresses().AsSlice()) @@ -602,8 +606,8 @@ type rule struct { protocol string } -func wantsServicesConfigured(cfgs *egressservices.Configs) bool { - return cfgs != nil && len(*cfgs) != 0 +func wantsServicesConfigured(cfgs egressservices.Configs) bool { + return cfgs != nil && len(cfgs) != 0 } func hasServicesConfigured(status *egressservices.Status) bool { @@ -657,13 +661,13 @@ func (ep *egressProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // would normally be this Pod. When this Pod is being deleted, the operator should have removed it from the Service // backends and eventually kube proxy routing rules should be updated to no longer route traffic for the Service to this // Pod. -func (ep *egressProxy) waitTillSafeToShutdown(ctx context.Context, cfgs *egressservices.Configs, hp int) { - if cfgs == nil || len(*cfgs) == 0 { // avoid sleeping if no services are configured +func (ep *egressProxy) waitTillSafeToShutdown(ctx context.Context, cfgs egressservices.Configs, hp int) { + if cfgs == nil || len(cfgs) == 0 { // avoid sleeping if no services are configured return } log.Printf("Ensuring that cluster traffic for egress targets is no longer routed via this Pod...") var wg sync.WaitGroup - for s, cfg := range *cfgs { + for s, cfg := range cfgs { hep := cfg.HealthCheckEndpoint if hep == "" { log.Printf("Tailnet target %q does not have a cluster healthcheck specified, unable to verify if cluster traffic for the target is still routed via this Pod", s) diff --git a/cmd/containerboot/egressservices_test.go b/cmd/containerboot/egressservices_test.go index 0d8504bdad7fd..b30765f19425a 100644 --- a/cmd/containerboot/egressservices_test.go +++ b/cmd/containerboot/egressservices_test.go @@ -255,13 +255,13 @@ func TestWaitTillSafeToShutdown(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cfgs := &egressservices.Configs{} + cfgs := egressservices.Configs{} switches := make(map[string]int) for svc, callsToSwitch := range tt.services { endpoint := fmt.Sprintf("http://%s.local", svc) if tt.healthCheckSet { - (*cfgs)[svc] = egressservices.Config{ + cfgs[svc] = egressservices.Config{ HealthCheckEndpoint: endpoint, } } diff --git a/cmd/containerboot/forwarding.go b/cmd/containerboot/forwarding.go index 0ec9c36c0bd30..6d90fbaaa9723 100644 --- a/cmd/containerboot/forwarding.go +++ b/cmd/containerboot/forwarding.go @@ -51,7 +51,7 @@ func ensureIPForwarding(root, clusterProxyTargetIP, tailnetTargetIP, tailnetTarg v4Forwarding = true } if routes != nil && *routes != "" { - for _, route := range strings.Split(*routes, ",") { + for route := range strings.SplitSeq(*routes, ",") { cidr, err := netip.ParsePrefix(route) if err != nil { return fmt.Errorf("invalid subnet route: %v", err) diff --git a/cmd/containerboot/ingressservices.go b/cmd/containerboot/ingressservices.go index d76bf86e0b8ec..d8ad017170379 100644 --- a/cmd/containerboot/ingressservices.go +++ b/cmd/containerboot/ingressservices.go @@ -265,7 +265,13 @@ func ensureIngressRulesAdded(cfgs map[string]ingressservices.Config, nfr linuxfw func addDNATRuleForSvc(nfr linuxfw.NetfilterRunner, serviceName string, tsIP, clusterIP netip.Addr) error { log.Printf("adding DNAT rule for Tailscale Service %s with IP %s to Kubernetes Service IP %s", serviceName, tsIP, clusterIP) - return nfr.EnsureDNATRuleForSvc(serviceName, tsIP, clusterIP) + if err := nfr.EnsureDNATRuleForSvc(serviceName, tsIP, clusterIP); err != nil { + return err + } + if err := nfr.ClampMSSToPMTU(tailscaleTunInterface, clusterIP); err != nil { + return fmt.Errorf("error clamping MSS to PMTU: %w", err) + } + return nil } // ensureIngressRulesDeleted takes a map of Tailscale Services and rules and ensures that the firewall rules are deleted. diff --git a/cmd/containerboot/ingressservices_test.go b/cmd/containerboot/ingressservices_test.go index 46330103e343b..1643bb11c069e 100644 --- a/cmd/containerboot/ingressservices_test.go +++ b/cmd/containerboot/ingressservices_test.go @@ -7,6 +7,7 @@ package main import ( "net/netip" + "slices" "testing" "tailscale.com/kube/ingressservices" @@ -22,6 +23,7 @@ func TestSyncIngressConfigs(t *testing.T) { TailscaleServiceIP netip.Addr ClusterIP netip.Addr } + wantClampedAddrs []netip.Addr // cluster IPs that should have MSS clamping applied }{ { name: "add_new_rules_when_no_existing_config", @@ -35,6 +37,7 @@ func TestSyncIngressConfigs(t *testing.T) { }{ "svc:foo": makeWantService("100.64.0.1", "10.0.0.1"), }, + wantClampedAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.1")}, }, { name: "add_multiple_services", @@ -52,6 +55,11 @@ func TestSyncIngressConfigs(t *testing.T) { "svc:bar": makeWantService("100.64.0.2", "10.0.0.2"), "svc:baz": makeWantService("100.64.0.3", "10.0.0.3"), }, + wantClampedAddrs: []netip.Addr{ + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("10.0.0.2"), + netip.MustParseAddr("10.0.0.3"), + }, }, { name: "add_both_ipv4_and_ipv6_rules", @@ -65,6 +73,10 @@ func TestSyncIngressConfigs(t *testing.T) { }{ "svc:foo": makeWantService("2001:db8::1", "2001:db8::2"), }, + wantClampedAddrs: []netip.Addr{ + netip.MustParseAddr("10.0.0.1"), + netip.MustParseAddr("2001:db8::2"), + }, }, { name: "add_ipv6_only_rules", @@ -78,6 +90,7 @@ func TestSyncIngressConfigs(t *testing.T) { }{ "svc:ipv6": makeWantService("2001:db8::10", "2001:db8::20"), }, + wantClampedAddrs: []netip.Addr{netip.MustParseAddr("2001:db8::20")}, }, { name: "delete_all_rules_when_config_removed", @@ -94,6 +107,7 @@ func TestSyncIngressConfigs(t *testing.T) { TailscaleServiceIP netip.Addr ClusterIP netip.Addr }{}, + wantClampedAddrs: nil, // no rules added, no clamping }, { name: "add_remove_modify", @@ -117,6 +131,10 @@ func TestSyncIngressConfigs(t *testing.T) { "svc:foo": makeWantService("100.64.0.1", "10.0.0.2"), "svc:new": makeWantService("100.64.0.4", "10.0.0.4"), }, + wantClampedAddrs: []netip.Addr{ + netip.MustParseAddr("10.0.0.2"), + netip.MustParseAddr("10.0.0.4"), + }, }, { name: "update_with_outdated_status", @@ -152,12 +170,17 @@ func TestSyncIngressConfigs(t *testing.T) { "svc:web-ipv6": makeWantService("2001:db8::10", "2001:db8::20"), "svc:api": makeWantService("100.64.0.20", "10.0.0.20"), }, + wantClampedAddrs: []netip.Addr{ + netip.MustParseAddr("10.0.0.10"), + netip.MustParseAddr("10.0.0.20"), + netip.MustParseAddr("2001:db8::20"), + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var nfr linuxfw.NetfilterRunner = linuxfw.NewFakeNetfilterRunner() + nfr := linuxfw.NewFakeNetfilterRunner() ep := &ingressProxy{ nfr: nfr, @@ -170,8 +193,7 @@ func TestSyncIngressConfigs(t *testing.T) { t.Fatalf("syncIngressConfigs failed: %v", err) } - fake := nfr.(*linuxfw.FakeNetfilterRunner) - gotServices := fake.GetServiceState() + gotServices := nfr.GetServiceState() if len(gotServices) != len(tt.wantServices) { t.Errorf("got %d services, want %d", len(gotServices), len(tt.wantServices)) } @@ -188,6 +210,20 @@ func TestSyncIngressConfigs(t *testing.T) { t.Errorf("service %s: got ClusterIP %v, want %v", svc, got.ClusterIP, want.ClusterIP) } } + + gotClamped := nfr.GetClampedAddrs() + slices.SortFunc(gotClamped, func(a, b netip.Addr) int { return a.Compare(b) }) + slices.SortFunc(tt.wantClampedAddrs, func(a, b netip.Addr) int { return a.Compare(b) }) + if len(gotClamped) != len(tt.wantClampedAddrs) { + t.Errorf("ClampMSSToPMTU: got %v, want %v", gotClamped, tt.wantClampedAddrs) + } else { + for i := range gotClamped { + if gotClamped[i] != tt.wantClampedAddrs[i] { + t.Errorf("ClampMSSToPMTU: got %v, want %v", gotClamped, tt.wantClampedAddrs) + break + } + } + } }) } } diff --git a/cmd/containerboot/settings.go b/cmd/containerboot/settings.go index e6147717bb39a..f695f2e5db5f4 100644 --- a/cmd/containerboot/settings.go +++ b/cmd/containerboot/settings.go @@ -89,24 +89,24 @@ type settings struct { func configFromEnv() (*settings, error) { cfg := &settings{ - AuthKey: defaultEnvs([]string{"TS_AUTHKEY", "TS_AUTH_KEY"}, ""), - ClientID: defaultEnv("TS_CLIENT_ID", ""), - ClientSecret: defaultEnv("TS_CLIENT_SECRET", ""), - IDToken: defaultEnv("TS_ID_TOKEN", ""), - Audience: defaultEnv("TS_AUDIENCE", ""), - Hostname: defaultEnv("TS_HOSTNAME", ""), - Routes: defaultEnvStringPointer("TS_ROUTES"), - ServeConfigPath: defaultEnv("TS_SERVE_CONFIG", ""), - ProxyTargetIP: defaultEnv("TS_DEST_IP", ""), - ProxyTargetDNSName: defaultEnv("TS_EXPERIMENTAL_DEST_DNS_NAME", ""), - TailnetTargetIP: defaultEnv("TS_TAILNET_TARGET_IP", ""), - TailnetTargetFQDN: defaultEnv("TS_TAILNET_TARGET_FQDN", ""), - DaemonExtraArgs: defaultEnv("TS_TAILSCALED_EXTRA_ARGS", ""), - ExtraArgs: defaultEnv("TS_EXTRA_ARGS", ""), - InKubernetes: os.Getenv("KUBERNETES_SERVICE_HOST") != "", - UserspaceMode: defaultBool("TS_USERSPACE", true), - StateDir: defaultEnv("TS_STATE_DIR", ""), - AcceptDNS: defaultEnvBoolPointer("TS_ACCEPT_DNS"), + AuthKey: defaultEnvs([]string{"TS_AUTHKEY", "TS_AUTH_KEY"}, ""), + ClientID: defaultEnv("TS_CLIENT_ID", ""), + ClientSecret: defaultEnv("TS_CLIENT_SECRET", ""), + IDToken: defaultEnv("TS_ID_TOKEN", ""), + Audience: defaultEnv("TS_AUDIENCE", ""), + Hostname: defaultEnv("TS_HOSTNAME", ""), + Routes: defaultEnvStringPointer("TS_ROUTES"), + ServeConfigPath: defaultEnv("TS_SERVE_CONFIG", ""), + ProxyTargetIP: defaultEnv("TS_DEST_IP", ""), + ProxyTargetDNSName: defaultEnv("TS_EXPERIMENTAL_DEST_DNS_NAME", ""), + TailnetTargetIP: defaultEnv("TS_TAILNET_TARGET_IP", ""), + TailnetTargetFQDN: defaultEnv("TS_TAILNET_TARGET_FQDN", ""), + DaemonExtraArgs: defaultEnv("TS_TAILSCALED_EXTRA_ARGS", ""), + ExtraArgs: defaultEnv("TS_EXTRA_ARGS", ""), + InKubernetes: os.Getenv("KUBERNETES_SERVICE_HOST") != "", + UserspaceMode: defaultBool("TS_USERSPACE", true), + StateDir: defaultEnv("TS_STATE_DIR", ""), + AcceptDNS: defaultEnvBoolPointer("TS_ACCEPT_DNS"), KubeSecret: func() string { if os.Getenv("KUBERNETES_SERVICE_HOST") != "" { return defaultEnv("TS_KUBE_SECRET", "tailscale") diff --git a/cmd/derper/bootstrap_dns_test.go b/cmd/derper/bootstrap_dns_test.go index 5b765f6d37b5f..2055b97511940 100644 --- a/cmd/derper/bootstrap_dns_test.go +++ b/cmd/derper/bootstrap_dns_test.go @@ -41,8 +41,28 @@ func (b *bitbucketResponseWriter) Write(p []byte) (int, error) { return len(p), func (b *bitbucketResponseWriter) WriteHeader(statusCode int) {} +// setDNSCache sets the published DNS cache for tests. +func setDNSCache(tb testing.TB, m *dnsEntryMap) { + tb.Helper() + j, err := json.Marshal(m.IPs) + if err != nil { + tb.Fatal(err) + } + tstest.AssertNotParallel(tb) + dnsCache.Store(m) + dnsCacheBytes.Store(j) + tb.Cleanup(func() { + dnsCache.Store(nil) + dnsCacheBytes.Store(nil) + }) +} + func getBootstrapDNS(t *testing.T, q string) map[string][]net.IP { t.Helper() + tstest.AssertNotParallel(t) + if dnsCache.Load() == nil { + t.Fatal("dnsCache not initialized; call setDNSCache before getBootstrapDNS") + } req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape(q), nil) w := httptest.NewRecorder() handleBootstrapDNS(w, req) @@ -100,7 +120,8 @@ func TestUnpublishedDNS(t *testing.T) { } } -func resetMetrics() { +func resetMetrics(tb testing.TB) { + tstest.AssertNotParallel(tb) publishedDNSHits.Set(0) publishedDNSMisses.Set(0) unpublishedDNSHits.Set(0) @@ -114,8 +135,7 @@ func TestUnpublishedDNSEmptyList(t *testing.T) { pub := &dnsEntryMap{ IPs: map[string][]net.IP{"tailscale.com": {net.IPv4(10, 10, 10, 10)}}, } - dnsCache.Store(pub) - dnsCacheBytes.Store([]byte(`{"tailscale.com":["10.10.10.10"]}`)) + setDNSCache(t, pub) unpublishedDNSCache.Store(&dnsEntryMap{ IPs: map[string][]net.IP{ @@ -131,7 +151,7 @@ func TestUnpublishedDNSEmptyList(t *testing.T) { t.Run("CacheMiss", func(t *testing.T) { // One domain in map but empty, one not in map at all for _, q := range []string{"log.tailscale.com", "login.tailscale.com"} { - resetMetrics() + resetMetrics(t) ips := getBootstrapDNS(t, q) // Expected our public map to be returned on a cache miss @@ -149,7 +169,7 @@ func TestUnpublishedDNSEmptyList(t *testing.T) { // Verify that we do get a valid response and metric. t.Run("CacheHit", func(t *testing.T) { - resetMetrics() + resetMetrics(t) ips := getBootstrapDNS(t, "controlplane.tailscale.com") want := map[string][]net.IP{"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}} if !reflect.DeepEqual(ips, want) { @@ -166,8 +186,10 @@ func TestUnpublishedDNSEmptyList(t *testing.T) { } func TestLookupMetric(t *testing.T) { + setDNSCache(t, &dnsEntryMap{}) + d := []string{"a.io", "b.io", "c.io", "d.io", "e.io", "e.io", "e.io", "a.io"} - resetMetrics() + resetMetrics(t) for _, q := range d { _ = getBootstrapDNS(t, q) } diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index a0eb4a29e259c..ec59c7264f501 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -138,11 +138,11 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/types/opt from tailscale.com/envknob+ tailscale.com/types/persist from tailscale.com/ipn+ tailscale.com/types/preftype from tailscale.com/ipn - tailscale.com/types/ptr from tailscale.com/hostinfo+ tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/ipn+ tailscale.com/types/tkatype from tailscale.com/client/local+ tailscale.com/types/views from tailscale.com/ipn+ + tailscale.com/util/bufiox from tailscale.com/derp/derpserver+ tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/net/netmon+ tailscale.com/util/cloudenv from tailscale.com/hostinfo+ diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index 87f9a0bc084e4..745d887f8bc06 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -87,6 +87,8 @@ var ( acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection") acceptConnBurst = flag.Int("accept-connection-burst", math.MaxInt, "burst limit for accepting new connection") + rateConfigPath = flag.String("rate-config", "", "if non-empty, path to JSON rate limit config file. Rate limiting is experimental and subject to change. Configuration is reloaded on SIGHUP.") + // tcpKeepAlive is intentionally long, to reduce battery cost. There is an L7 keepalive on a higher frequency schedule. tcpKeepAlive = flag.Duration("tcp-keepalive-time", 10*time.Minute, "TCP keepalive time") // tcpUserTimeout is intentionally short, so that hung connections are cleaned up promptly. DERPs should be nearby users. @@ -192,6 +194,12 @@ func main() { s.SetVerifyClientURL(*verifyClientURL) s.SetVerifyClientURLFailOpen(*verifyFailOpen) s.SetTCPWriteTimeout(*tcpWriteTimeout) + if *rateConfigPath != "" { + if err := s.LoadAndApplyRateConfig(*rateConfigPath); err != nil { + log.Fatalf("derper: loading rate config: %v", err) + } + go watchRateConfig(ctx, s, *rateConfigPath) + } var meshKey string if *dev { @@ -244,7 +252,7 @@ func main() { if err := startMesh(s); err != nil { log.Fatalf("startMesh: %v", err) } - expvar.Publish("derp", s.ExpVar()) + expvar.Publish("derp", s.ExpVar(*rateConfigPath != "")) handleHome, ok := getHomeHandler(*flagHome) if !ok { @@ -426,6 +434,27 @@ func main() { } } +// watchRateConfig listens for SIGHUP signals and reloads the rate config +// file on each signal, applying it to the server. It returns when ctx is done. +func watchRateConfig(ctx context.Context, s *derpserver.Server, path string) { + sighup := make(chan os.Signal, 1) + signal.Notify(sighup, syscall.SIGHUP) + defer signal.Stop(sighup) + for { + select { + case <-ctx.Done(): + return + case <-sighup: + log.Printf("derper: received SIGHUP, reloading rate config from %s", path) + if err := s.LoadAndApplyRateConfig(path); err != nil { + log.Printf("derper: rate config reload failed: %v", err) + continue + } + log.Printf("derper: rate config reloaded successfully") + } + } +} + var validProdHostname = regexp.MustCompile(`^derp([^.]*)\.tailscale\.com\.?$`) func prodAutocertHostPolicy(_ context.Context, host string) error { diff --git a/cmd/derper/derper_test.go b/cmd/derper/derper_test.go index 0a2fd8787d61d..fc1ebd6930dd6 100644 --- a/cmd/derper/derper_test.go +++ b/cmd/derper/derper_test.go @@ -46,30 +46,30 @@ func TestNoContent(t *testing.T) { want string }{ { - name: "no challenge", + name: "no-challenge", }, { - name: "valid challenge", + name: "valid-challenge", input: "input", want: "response input", }, { - name: "valid challenge hostname", + name: "valid-challenge-hostname", input: "ts_derp99b.tailscale.com", want: "response ts_derp99b.tailscale.com", }, { - name: "invalid challenge", + name: "invalid-challenge", input: "foo\x00bar", want: "", }, { - name: "whitespace invalid challenge", + name: "whitespace-invalid-challenge", input: "foo bar", want: "", }, { - name: "long challenge", + name: "long-challenge", input: strings.Repeat("x", 65), want: "", }, diff --git a/cmd/derper/mesh.go b/cmd/derper/mesh.go index 34ea7da856220..c07cfe969d9e3 100644 --- a/cmd/derper/mesh.go +++ b/cmd/derper/mesh.go @@ -25,7 +25,7 @@ func startMesh(s *derpserver.Server) error { if !s.HasMeshKey() { return errors.New("--mesh-with requires --mesh-psk-file") } - for _, hostTuple := range strings.Split(*meshWith, ",") { + for hostTuple := range strings.SplitSeq(*meshWith, ",") { if err := startMeshWithHost(s, hostTuple); err != nil { return err } diff --git a/cmd/gitops-pusher/gitops-pusher.go b/cmd/gitops-pusher/gitops-pusher.go index 11448e30da1aa..9ea115a1585e7 100644 --- a/cmd/gitops-pusher/gitops-pusher.go +++ b/cmd/gitops-pusher/gitops-pusher.go @@ -26,7 +26,7 @@ import ( "github.com/tailscale/hujson" "golang.org/x/oauth2/clientcredentials" tsclient "tailscale.com/client/tailscale" - _ "tailscale.com/feature/condregister/identityfederation" + _ "tailscale.com/feature/identityfederation" "tailscale.com/internal/client/tailscale" "tailscale.com/util/httpm" ) diff --git a/cmd/gitops-pusher/gitops-pusher_test.go b/cmd/gitops-pusher/gitops-pusher_test.go index bc339ae6a0b84..8d785e8cf793a 100644 --- a/cmd/gitops-pusher/gitops-pusher_test.go +++ b/cmd/gitops-pusher/gitops-pusher_test.go @@ -30,7 +30,7 @@ func TestEmbeddedTypeUnmarshal(t *testing.T) { }, } - t.Run("unmarshal gitops type from acl type", func(t *testing.T) { + t.Run("unmarshal-gitops-from-acl", func(t *testing.T) { b, _ := json.Marshal(aclTestErr) var e ACLGitopsTestError err := json.Unmarshal(b, &e) @@ -41,7 +41,7 @@ func TestEmbeddedTypeUnmarshal(t *testing.T) { t.Fatalf("user heading for 'ACLError' not found in gitops error: %v", e.Error()) } }) - t.Run("unmarshal acl type from gitops type", func(t *testing.T) { + t.Run("unmarshal-acl-from-gitops", func(t *testing.T) { b, _ := json.Marshal(gitopsErr) var e tailscale.ACLTestError err := json.Unmarshal(b, &e) diff --git a/cmd/hello/hello.go b/cmd/hello/hello.go index 710de49cd67a8..45eb7751c3790 100644 --- a/cmd/hello/hello.go +++ b/cmd/hello/hello.go @@ -5,212 +5,16 @@ package main // import "tailscale.com/cmd/hello" import ( - "context" - "crypto/tls" - _ "embed" - "encoding/json" - "errors" - "flag" - "html/template" "log" - "net/http" - "os" - "strings" - "time" - "tailscale.com/client/local" - "tailscale.com/client/tailscale/apitype" - "tailscale.com/tailcfg" + "tailscale.com/cmd/hello/helloserver" ) -var ( - httpAddr = flag.String("http", ":80", "address to run an HTTP server on, or empty for none") - httpsAddr = flag.String("https", ":443", "address to run an HTTPS server on, or empty for none") - testIP = flag.String("test-ip", "", "if non-empty, look up IP and exit before running a server") -) - -//go:embed hello.tmpl.html -var embeddedTemplate string - -var localClient local.Client - func main() { - flag.Parse() - if *testIP != "" { - res, err := localClient.WhoIs(context.Background(), *testIP) - if err != nil { - log.Fatal(err) - } - e := json.NewEncoder(os.Stdout) - e.SetIndent("", "\t") - e.Encode(res) - return + s := &helloserver.Server{ + HTTPAddr: ":80", + HTTPSAddr: ":443", } - if devMode() { - // Parse it optimistically - var err error - tmpl, err = template.New("home").Parse(embeddedTemplate) - if err != nil { - log.Printf("ignoring template error in dev mode: %v", err) - } - } else { - if embeddedTemplate == "" { - log.Fatalf("embeddedTemplate is empty; must be build with Go 1.16+") - } - tmpl = template.Must(template.New("home").Parse(embeddedTemplate)) - } - - http.HandleFunc("/", root) log.Printf("Starting hello server.") - - errc := make(chan error, 1) - if *httpAddr != "" { - log.Printf("running HTTP server on %s", *httpAddr) - go func() { - errc <- http.ListenAndServe(*httpAddr, nil) - }() - } - if *httpsAddr != "" { - log.Printf("running HTTPS server on %s", *httpsAddr) - go func() { - hs := &http.Server{ - Addr: *httpsAddr, - TLSConfig: &tls.Config{ - GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { - switch hi.ServerName { - case "hello.ts.net": - return localClient.GetCertificate(hi) - case "hello.ipn.dev": - c, err := tls.LoadX509KeyPair( - "/etc/hello/hello.ipn.dev.crt", - "/etc/hello/hello.ipn.dev.key", - ) - if err != nil { - return nil, err - } - return &c, nil - } - return nil, errors.New("invalid SNI name") - }, - }, - IdleTimeout: 30 * time.Second, - ReadHeaderTimeout: 20 * time.Second, - MaxHeaderBytes: 10 << 10, - } - errc <- hs.ListenAndServeTLS("", "") - }() - } - log.Fatal(<-errc) -} - -func devMode() bool { return *httpsAddr == "" && *httpAddr != "" } - -func getTmpl() (*template.Template, error) { - if devMode() { - tmplData, err := os.ReadFile("hello.tmpl.html") - if os.IsNotExist(err) { - log.Printf("using baked-in template in dev mode; can't find hello.tmpl.html in current directory") - return tmpl, nil - } - return template.New("home").Parse(string(tmplData)) - } - return tmpl, nil -} - -// tmpl is the template used in prod mode. -// In dev mode it's only used if the template file doesn't exist on disk. -// It's initialized by main after flag parsing. -var tmpl *template.Template - -type tmplData struct { - DisplayName string // "Foo Barberson" - LoginName string // "foo@bar.com" - ProfilePicURL string // "https://..." - MachineName string // "imac5k" - MachineOS string // "Linux" - IP string // "100.2.3.4" -} - -func tailscaleIP(who *apitype.WhoIsResponse) string { - if who == nil { - return "" - } - vals, err := tailcfg.UnmarshalNodeCapJSON[string](who.Node.CapMap, tailcfg.NodeAttrNativeIPV4) - if err == nil && len(vals) > 0 { - return vals[0] - } - for _, nodeIP := range who.Node.Addresses { - if nodeIP.Addr().Is4() && nodeIP.IsSingleIP() { - return nodeIP.Addr().String() - } - } - for _, nodeIP := range who.Node.Addresses { - if nodeIP.IsSingleIP() { - return nodeIP.Addr().String() - } - } - return "" -} - -func root(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil && *httpsAddr != "" { - host := r.Host - if strings.Contains(r.Host, "100.101.102.103") || - strings.Contains(r.Host, "hello.ipn.dev") { - host = "hello.ts.net" - } - http.Redirect(w, r, "https://"+host, http.StatusFound) - return - } - if r.RequestURI != "/" { - http.Redirect(w, r, "/", http.StatusFound) - return - } - if r.TLS != nil && *httpsAddr != "" && strings.Contains(r.Host, "hello.ipn.dev") { - http.Redirect(w, r, "https://hello.ts.net", http.StatusFound) - return - } - tmpl, err := getTmpl() - if err != nil { - w.Header().Set("Content-Type", "text/plain") - http.Error(w, "template error: "+err.Error(), 500) - return - } - - who, err := localClient.WhoIs(r.Context(), r.RemoteAddr) - var data tmplData - if err != nil { - if devMode() { - log.Printf("warning: using fake data in dev mode due to whois lookup error: %v", err) - data = tmplData{ - DisplayName: "Taily Scalerson", - LoginName: "taily@scaler.son", - ProfilePicURL: "https://placekitten.com/200/200", - MachineName: "scaled", - MachineOS: "Linux", - IP: "100.1.2.3", - } - } else { - log.Printf("whois(%q) error: %v", r.RemoteAddr, err) - http.Error(w, "Your Tailscale works, but we failed to look you up.", 500) - return - } - } else { - data = tmplData{ - DisplayName: who.UserProfile.DisplayName, - LoginName: who.UserProfile.LoginName, - ProfilePicURL: who.UserProfile.ProfilePicURL, - MachineName: firstLabel(who.Node.ComputedName), - MachineOS: who.Node.Hostinfo.OS(), - IP: tailscaleIP(who), - } - } - w.Header().Set("Content-Type", "text/html; charset=utf-8") - tmpl.Execute(w, data) -} - -// firstLabel s up until the first period, if any. -func firstLabel(s string) string { - s, _, _ = strings.Cut(s, ".") - return s + log.Fatal(s.Run()) } diff --git a/cmd/hello/hello.tmpl.html b/cmd/hello/helloserver/hello.tmpl.html similarity index 100% rename from cmd/hello/hello.tmpl.html rename to cmd/hello/helloserver/hello.tmpl.html diff --git a/cmd/hello/helloserver/helloserver.go b/cmd/hello/helloserver/helloserver.go new file mode 100644 index 0000000000000..8d5972b83c8c5 --- /dev/null +++ b/cmd/hello/helloserver/helloserver.go @@ -0,0 +1,146 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package helloserver implements the HTTP server behind hello.ts.net. +package helloserver + +import ( + "crypto/tls" + _ "embed" + "html/template" + "log" + "net/http" + "strings" + "time" + + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" +) + +//go:embed hello.tmpl.html +var embeddedTemplate string + +var tmpl = template.Must(template.New("home").Parse(embeddedTemplate)) + +// Server is an HTTP server for hello.ts.net. +// +// The zero value is not valid; populate at least one of HTTPAddr or HTTPSAddr +// before calling Run. +type Server struct { + // HTTPAddr is the address to run an HTTP server on, or empty for none. + HTTPAddr string + + // HTTPSAddr is the address to run an HTTPS server on, or empty for none. + HTTPSAddr string + + // LocalClient is used to look up the identity of incoming requests and + // to obtain TLS certificates. If nil, the zero value of local.Client is + // used. + LocalClient *local.Client +} + +func (s *Server) localClient() *local.Client { + if s.LocalClient != nil { + return s.LocalClient + } + return &local.Client{} +} + +// Run starts the configured HTTP and HTTPS servers and blocks until one of +// them returns an error. +func (s *Server) Run() error { + errc := make(chan error, 1) + if s.HTTPAddr != "" { + log.Printf("running HTTP server on %s", s.HTTPAddr) + go func() { + errc <- http.ListenAndServe(s.HTTPAddr, s) + }() + } + if s.HTTPSAddr != "" { + log.Printf("running HTTPS server on %s", s.HTTPSAddr) + go func() { + hs := &http.Server{ + Addr: s.HTTPSAddr, + Handler: s, + TLSConfig: &tls.Config{ + GetCertificate: s.localClient().GetCertificate, + }, + IdleTimeout: 30 * time.Second, + ReadHeaderTimeout: 20 * time.Second, + MaxHeaderBytes: 10 << 10, + } + errc <- hs.ListenAndServeTLS("", "") + }() + } + return <-errc +} + +type tmplData struct { + DisplayName string // "Foo Barberson" + LoginName string // "foo@bar.com" + ProfilePicURL string // "https://..." + MachineName string // "imac5k" + MachineOS string // "Linux" + IP string // "100.2.3.4" +} + +func tailscaleIP(who *apitype.WhoIsResponse) string { + if who == nil { + return "" + } + vals, err := tailcfg.UnmarshalNodeCapJSON[string](who.Node.CapMap, tailcfg.NodeAttrNativeIPV4) + if err == nil && len(vals) > 0 { + return vals[0] + } + for _, nodeIP := range who.Node.Addresses { + if nodeIP.Addr().Is4() && nodeIP.IsSingleIP() { + return nodeIP.Addr().String() + } + } + for _, nodeIP := range who.Node.Addresses { + if nodeIP.IsSingleIP() { + return nodeIP.Addr().String() + } + } + return "" +} + +// ServeHTTP implements http.Handler. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.TLS == nil && s.HTTPSAddr != "" { + host := r.Host + if strings.Contains(r.Host, "100.101.102.103") { + host = "hello.ts.net" + } + http.Redirect(w, r, "https://"+host, http.StatusFound) + return + } + if r.RequestURI != "/" { + http.Redirect(w, r, "/", http.StatusFound) + return + } + + who, err := s.localClient().WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + log.Printf("whois(%q) error: %v", r.RemoteAddr, err) + http.Error(w, "Your Tailscale works, but we failed to look you up.", 500) + return + } + data := tmplData{ + DisplayName: who.UserProfile.DisplayName, + LoginName: who.UserProfile.LoginName, + ProfilePicURL: who.UserProfile.ProfilePicURL, + MachineName: firstLabel(who.Node.ComputedName), + MachineOS: who.Node.Hostinfo.OS(), + IP: tailscaleIP(who), + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + tmpl.Execute(w, data) +} + +// firstLabel returns s up until the first period, if any. +func firstLabel(s string) string { + s, _, _ = strings.Cut(s, ".") + return s +} diff --git a/cmd/k8s-nameserver/main_test.go b/cmd/k8s-nameserver/main_test.go index 0624800836675..b5cd8c907d522 100644 --- a/cmd/k8s-nameserver/main_test.go +++ b/cmd/k8s-nameserver/main_test.go @@ -24,7 +24,7 @@ func TestNameserver(t *testing.T) { wantResp *dns.Msg }{ { - name: "A record query, record exists", + name: "A-record-exists", ip4: map[dnsname.FQDN][]net.IP{dnsname.FQDN("foo.bar.com."): {{1, 2, 3, 4}}}, query: &dns.Msg{ Question: []dns.Question{{Name: "foo.bar.com", Qtype: dns.TypeA}}, @@ -46,7 +46,7 @@ func TestNameserver(t *testing.T) { }}, }, { - name: "A record query, record does not exist", + name: "A-record-not-exists", ip4: map[dnsname.FQDN][]net.IP{dnsname.FQDN("foo.bar.com."): {{1, 2, 3, 4}}}, query: &dns.Msg{ Question: []dns.Question{{Name: "baz.bar.com", Qtype: dns.TypeA}}, @@ -64,7 +64,7 @@ func TestNameserver(t *testing.T) { }}, }, { - name: "A record query, but the name is not a valid FQDN", + name: "A-record-invalid-FQDN", ip4: map[dnsname.FQDN][]net.IP{dnsname.FQDN("foo.bar.com."): {{1, 2, 3, 4}}}, query: &dns.Msg{ Question: []dns.Question{{Name: "foo..bar.com", Qtype: dns.TypeA}}, @@ -80,7 +80,7 @@ func TestNameserver(t *testing.T) { }}, }, { - name: "AAAA record query, A record exists", + name: "AAAA-query-A-record-exists", ip4: map[dnsname.FQDN][]net.IP{dnsname.FQDN("foo.bar.com."): {{1, 2, 3, 4}}}, query: &dns.Msg{ Question: []dns.Question{{Name: "foo.bar.com", Qtype: dns.TypeAAAA}}, @@ -97,7 +97,7 @@ func TestNameserver(t *testing.T) { }}, }, { - name: "AAAA record query, A record does not exist", + name: "AAAA-query-A-record-not-exists", ip4: map[dnsname.FQDN][]net.IP{dnsname.FQDN("foo.bar.com."): {{1, 2, 3, 4}}}, query: &dns.Msg{ Question: []dns.Question{{Name: "baz.bar.com", Qtype: dns.TypeAAAA}}, @@ -114,7 +114,7 @@ func TestNameserver(t *testing.T) { }}, }, { - name: "AAAA record query with IPv6 record", + name: "AAAA-query-ipv6-record", ip6: map[dnsname.FQDN][]net.IP{dnsname.FQDN("foo.bar.com."): {net.ParseIP("2001:db8::1")}}, query: &dns.Msg{ Question: []dns.Question{{Name: "foo.bar.com", Qtype: dns.TypeAAAA}}, @@ -136,7 +136,7 @@ func TestNameserver(t *testing.T) { }}, }, { - name: "Dual-stack: both A and AAAA records exist", + name: "dual-stack-A-and-AAAA", ip4: map[dnsname.FQDN][]net.IP{dnsname.FQDN("dual.bar.com."): {{10, 0, 0, 1}}}, ip6: map[dnsname.FQDN][]net.IP{dnsname.FQDN("dual.bar.com."): {net.ParseIP("2001:db8::1")}}, query: &dns.Msg{ @@ -157,7 +157,7 @@ func TestNameserver(t *testing.T) { }}, }, { - name: "CNAME record query", + name: "CNAME-query", ip4: map[dnsname.FQDN][]net.IP{dnsname.FQDN("foo.bar.com."): {{1, 2, 3, 4}}}, query: &dns.Msg{ Question: []dns.Question{{Name: "foo.bar.com", Qtype: dns.TypeCNAME}}, @@ -200,20 +200,20 @@ func TestResetRecords(t *testing.T) { wantsErr bool }{ { - name: "previously empty nameserver.ip4 gets set", + name: "previously-empty-nameserver-ip4-gets-set", config: []byte(`{"version": "v1alpha1", "ip4": {"foo.bar.com": ["1.2.3.4"]}}`), wantsIp4: map[dnsname.FQDN][]net.IP{"foo.bar.com.": {{1, 2, 3, 4}}}, wantsIp6: make(map[dnsname.FQDN][]net.IP), }, { - name: "nameserver.ip4 gets reset", + name: "nameserver-ip4-gets-reset", hasIp4: map[dnsname.FQDN][]net.IP{"baz.bar.com.": {{1, 1, 3, 3}}}, config: []byte(`{"version": "v1alpha1", "ip4": {"foo.bar.com": ["1.2.3.4"]}}`), wantsIp4: map[dnsname.FQDN][]net.IP{"foo.bar.com.": {{1, 2, 3, 4}}}, wantsIp6: make(map[dnsname.FQDN][]net.IP), }, { - name: "configuration with incompatible version", + name: "configuration-with-incompatible-version", hasIp4: map[dnsname.FQDN][]net.IP{"baz.bar.com.": {{1, 1, 3, 3}}}, config: []byte(`{"version": "v1beta1", "ip4": {"foo.bar.com": ["1.2.3.4"]}}`), wantsIp4: map[dnsname.FQDN][]net.IP{"baz.bar.com.": {{1, 1, 3, 3}}}, @@ -221,26 +221,26 @@ func TestResetRecords(t *testing.T) { wantsErr: true, }, { - name: "nameserver.ip4 gets reset to empty config when no configuration is provided", + name: "nameserver-ip4-gets-reset-to-empty-config-when-no-configuration-is-provided", hasIp4: map[dnsname.FQDN][]net.IP{"baz.bar.com.": {{1, 1, 3, 3}}}, wantsIp4: make(map[dnsname.FQDN][]net.IP), wantsIp6: make(map[dnsname.FQDN][]net.IP), }, { - name: "nameserver.ip4 gets reset to empty config when the provided configuration is empty", + name: "nameserver-ip4-gets-reset-to-empty-config-when-the-provided-configuration-is-empty", hasIp4: map[dnsname.FQDN][]net.IP{"baz.bar.com.": {{1, 1, 3, 3}}}, config: []byte(`{"version": "v1alpha1", "ip4": {}}`), wantsIp4: make(map[dnsname.FQDN][]net.IP), wantsIp6: make(map[dnsname.FQDN][]net.IP), }, { - name: "nameserver.ip6 gets set", + name: "nameserver-ip6-gets-set", config: []byte(`{"version": "v1alpha1", "ip6": {"foo.bar.com": ["2001:db8::1"]}}`), wantsIp4: make(map[dnsname.FQDN][]net.IP), wantsIp6: map[dnsname.FQDN][]net.IP{"foo.bar.com.": {net.ParseIP("2001:db8::1")}}, }, { - name: "dual-stack configuration", + name: "dual-stack-configuration", config: []byte(`{"version": "v1alpha1", "ip4": {"dual.bar.com": ["10.0.0.1"]}, "ip6": {"dual.bar.com": ["2001:db8::1"]}}`), wantsIp4: map[dnsname.FQDN][]net.IP{"dual.bar.com.": {{10, 0, 0, 1}}}, wantsIp6: map[dnsname.FQDN][]net.IP{"dual.bar.com.": {net.ParseIP("2001:db8::1")}}, diff --git a/cmd/k8s-operator/api-server-proxy.go b/cmd/k8s-operator/api-server-proxy.go index 492590c9fecd6..b8d87cf0aa38a 100644 --- a/cmd/k8s-operator/api-server-proxy.go +++ b/cmd/k8s-operator/api-server-proxy.go @@ -11,7 +11,6 @@ import ( "os" "tailscale.com/kube/kubetypes" - "tailscale.com/types/ptr" ) func parseAPIProxyMode() *kubetypes.APIServerProxyMode { @@ -23,18 +22,18 @@ func parseAPIProxyMode() *kubetypes.APIServerProxyMode { case haveAuthProxyEnv: var authProxyEnv = defaultBool("AUTH_PROXY", false) // deprecated if authProxyEnv { - return ptr.To(kubetypes.APIServerProxyModeAuth) + return new(kubetypes.APIServerProxyModeAuth) } return nil case haveAPIProxyEnv: var apiProxyEnv = defaultEnv("APISERVER_PROXY", "") // true, false or "noauth" switch apiProxyEnv { case "true": - return ptr.To(kubetypes.APIServerProxyModeAuth) + return new(kubetypes.APIServerProxyModeAuth) case "false", "": return nil case "noauth": - return ptr.To(kubetypes.APIServerProxyModeNoAuth) + return new(kubetypes.APIServerProxyModeNoAuth) default: panic(fmt.Sprintf("unknown APISERVER_PROXY value %q", apiProxyEnv)) } diff --git a/cmd/k8s-operator/connector_test.go b/cmd/k8s-operator/connector_test.go index 7866f3e002921..69e8e287d07d1 100644 --- a/cmd/k8s-operator/connector_test.go +++ b/cmd/k8s-operator/connector_test.go @@ -19,10 +19,11 @@ import ( "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client/fake" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/k8s-operator/tsclient" "tailscale.com/kube/kubetypes" "tailscale.com/tstest" - "tailscale.com/types/ptr" "tailscale.com/util/mak" ) @@ -39,7 +40,7 @@ func TestConnector(t *testing.T) { APIVersion: "tailscale.com/v1alpha1", }, Spec: tsapi.ConnectorSpec{ - Replicas: ptr.To[int32](1), + Replicas: new(int32(1)), SubnetRouter: &tsapi.SubnetRouter{ AdvertiseRoutes: []tsapi.Route{"10.40.0.0/14"}, }, @@ -63,7 +64,7 @@ func TestConnector(t *testing.T) { recorder: record.NewFakeRecorder(10), ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -166,7 +167,7 @@ func TestConnector(t *testing.T) { APIVersion: "tailscale.io/v1alpha1", }, Spec: tsapi.ConnectorSpec{ - Replicas: ptr.To[int32](1), + Replicas: new(int32(1)), SubnetRouter: &tsapi.SubnetRouter{ AdvertiseRoutes: []tsapi.Route{"10.40.0.0/14"}, }, @@ -229,7 +230,7 @@ func TestConnectorWithProxyClass(t *testing.T) { APIVersion: "tailscale.io/v1alpha1", }, Spec: tsapi.ConnectorSpec{ - Replicas: ptr.To[int32](1), + Replicas: new(int32(1)), SubnetRouter: &tsapi.SubnetRouter{ AdvertiseRoutes: []tsapi.Route{"10.40.0.0/14"}, }, @@ -253,7 +254,7 @@ func TestConnectorWithProxyClass(t *testing.T) { clock: cl, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -326,7 +327,7 @@ func TestConnectorWithAppConnector(t *testing.T) { APIVersion: "tailscale.io/v1alpha1", }, Spec: tsapi.ConnectorSpec{ - Replicas: ptr.To[int32](1), + Replicas: new(int32(1)), AppConnector: &tsapi.AppConnector{}, }, } @@ -347,7 +348,7 @@ func TestConnectorWithAppConnector(t *testing.T) { clock: cl, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -425,7 +426,7 @@ func TestConnectorWithMultipleReplicas(t *testing.T) { APIVersion: "tailscale.io/v1alpha1", }, Spec: tsapi.ConnectorSpec{ - Replicas: ptr.To[int32](3), + Replicas: new(int32(3)), AppConnector: &tsapi.AppConnector{}, HostnamePrefix: "test-connector", }, @@ -447,7 +448,7 @@ func TestConnectorWithMultipleReplicas(t *testing.T) { clock: cl, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -496,7 +497,7 @@ func TestConnectorWithMultipleReplicas(t *testing.T) { // 5. We'll scale the connector down by 1 replica and make sure its secret is cleaned up mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { - conn.Spec.Replicas = ptr.To[int32](2) + conn.Spec.Replicas = new(int32(2)) }) expectReconciled(t, cr, "", "test") names = findGenNames(t, fc, "", "test", "connector") diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 8718127b6e75f..dd0d29a53b0d9 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -6,77 +6,6 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ W đŸ’Ŗ github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W đŸ’Ŗ github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy - github.com/aws/aws-sdk-go-v2/aws from github.com/aws/aws-sdk-go-v2/aws/defaults+ - github.com/aws/aws-sdk-go-v2/aws/defaults from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/aws/middleware from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/aws-sdk-go-v2/aws/protocol/query from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/aws/protocol/restjson from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/aws/protocol/xml from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/aws/ratelimit from github.com/aws/aws-sdk-go-v2/aws/retry - github.com/aws/aws-sdk-go-v2/aws/retry from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client+ - github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 from github.com/aws/aws-sdk-go-v2/aws/signer/v4 - github.com/aws/aws-sdk-go-v2/aws/signer/v4 from github.com/aws/aws-sdk-go-v2/internal/auth/smithy+ - github.com/aws/aws-sdk-go-v2/aws/transport/http from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/config from tailscale.com/wif - github.com/aws/aws-sdk-go-v2/credentials from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/endpointcreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds - github.com/aws/aws-sdk-go-v2/credentials/processcreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/ssocreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/stscreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/feature/ec2/imds from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/feature/ec2/imds/internal/config from github.com/aws/aws-sdk-go-v2/feature/ec2/imds - github.com/aws/aws-sdk-go-v2/internal/auth from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ - github.com/aws/aws-sdk-go-v2/internal/auth/smithy from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/configsources from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/context from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/aws-sdk-go-v2/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 from github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints+ - github.com/aws/aws-sdk-go-v2/internal/ini from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/internal/middleware from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/rand from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/aws-sdk-go-v2/internal/sdk from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/aws-sdk-go-v2/internal/sdkio from github.com/aws/aws-sdk-go-v2/credentials/processcreds - github.com/aws/aws-sdk-go-v2/internal/shareddefaults from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/internal/strings from github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 - github.com/aws/aws-sdk-go-v2/internal/sync/singleflight from github.com/aws/aws-sdk-go-v2/aws - github.com/aws/aws-sdk-go-v2/internal/timeconv from github.com/aws/aws-sdk-go-v2/aws/retry - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/sso from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso - github.com/aws/aws-sdk-go-v2/service/sso/types from github.com/aws/aws-sdk-go-v2/service/sso - github.com/aws/aws-sdk-go-v2/service/ssooidc from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/ssooidc/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/aws-sdk-go-v2/service/ssooidc/types from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/aws-sdk-go-v2/service/sts from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/sts/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/sts/types from github.com/aws/aws-sdk-go-v2/credentials/stscreds+ - github.com/aws/smithy-go from github.com/aws/aws-sdk-go-v2/aws/protocol/restjson+ - github.com/aws/smithy-go/auth from github.com/aws/aws-sdk-go-v2/internal/auth+ - github.com/aws/smithy-go/auth/bearer from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/context from github.com/aws/smithy-go/auth/bearer - github.com/aws/smithy-go/document from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/encoding from github.com/aws/smithy-go/encoding/json+ - github.com/aws/smithy-go/encoding/httpbinding from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ - github.com/aws/smithy-go/encoding/json from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/smithy-go/encoding/xml from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/smithy-go/endpoints from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/endpoints/private/rulesfn from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/smithy-go/internal/sync/singleflight from github.com/aws/smithy-go/auth/bearer - github.com/aws/smithy-go/io from github.com/aws/aws-sdk-go-v2/feature/ec2/imds+ - github.com/aws/smithy-go/logging from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/smithy-go/middleware from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/private/requestcompression from github.com/aws/aws-sdk-go-v2/config - github.com/aws/smithy-go/ptr from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/rand from github.com/aws/aws-sdk-go-v2/aws/middleware - github.com/aws/smithy-go/time from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/tracing from github.com/aws/aws-sdk-go-v2/aws/middleware+ - github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http github.com/beorn7/perks/quantile from github.com/prometheus/client_golang/prometheus github.com/blang/semver/v4 from k8s.io/component-base/metrics đŸ’Ŗ github.com/cespare/xxhash/v2 from github.com/prometheus/client_golang/prometheus+ @@ -130,7 +59,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/google/gnostic-models/jsonschema from github.com/google/gnostic-models/compiler github.com/google/gnostic-models/openapiv2 from k8s.io/client-go/discovery+ github.com/google/gnostic-models/openapiv3 from k8s.io/kube-openapi/pkg/handler3+ - github.com/google/uuid from github.com/prometheus-community/pro-bing+ + github.com/google/uuid from k8s.io/apimachinery/pkg/util/uuid+ github.com/hdevalence/ed25519consensus from tailscale.com/tka github.com/huin/goupnp from github.com/huin/goupnp/dcps/internetgateway2+ github.com/huin/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper @@ -164,7 +93,6 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal+ github.com/pkg/errors from github.com/evanphx/json-patch/v5+ github.com/pmezard/go-difflib/difflib from k8s.io/apimachinery/pkg/util/diff - D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack github.com/prometheus/client_golang/internal/github.com/golang/gddo/httputil from github.com/prometheus/client_golang/prometheus/promhttp github.com/prometheus/client_golang/internal/github.com/golang/gddo/httputil/header from github.com/prometheus/client_golang/internal/github.com/golang/gddo/httputil đŸ’Ŗ github.com/prometheus/client_golang/prometheus from github.com/prometheus/client_golang/prometheus/collectors+ @@ -180,7 +108,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ LD github.com/prometheus/procfs/internal/util from github.com/prometheus/procfs L đŸ’Ŗ github.com/safchain/ethtool from tailscale.com/net/netkernelconf github.com/spf13/pflag from k8s.io/client-go/tools/clientcmd+ - W đŸ’Ŗ github.com/tailscale/certstore from tailscale.com/control/controlclient + DW đŸ’Ŗ github.com/tailscale/certstore from tailscale.com/control/controlclient W đŸ’Ŗ github.com/tailscale/go-winio from tailscale.com/safesocket W đŸ’Ŗ github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio W đŸ’Ŗ github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio @@ -784,8 +712,9 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/appc from tailscale.com/ipn/ipnlocal đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/ipn+ tailscale.com/client/local from tailscale.com/client/tailscale+ - tailscale.com/client/tailscale from tailscale.com/cmd/k8s-operator+ + tailscale.com/client/tailscale from tailscale.com/internal/client/tailscale tailscale.com/client/tailscale/apitype from tailscale.com/client/tailscale+ + tailscale.com/client/tailscale/v2 from tailscale.com/cmd/k8s-operator+ tailscale.com/client/web from tailscale.com/ipn/ipnlocal tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ tailscale.com/control/controlclient from tailscale.com/ipn/ipnlocal+ @@ -804,11 +733,9 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/feature/buildfeatures from tailscale.com/wgengine/magicsock+ tailscale.com/feature/c2n from tailscale.com/tsnet tailscale.com/feature/condlite/expvar from tailscale.com/wgengine/magicsock - tailscale.com/feature/condregister/identityfederation from tailscale.com/tsnet tailscale.com/feature/condregister/oauthkey from tailscale.com/tsnet tailscale.com/feature/condregister/portmapper from tailscale.com/tsnet tailscale.com/feature/condregister/useproxy from tailscale.com/tsnet - tailscale.com/feature/identityfederation from tailscale.com/feature/condregister/identityfederation tailscale.com/feature/oauthkey from tailscale.com/feature/condregister/oauthkey tailscale.com/feature/portmapper from tailscale.com/feature/condregister/portmapper tailscale.com/feature/syspolicy from tailscale.com/logpolicy @@ -816,11 +743,11 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/health from tailscale.com/control/controlclient+ tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal tailscale.com/hostinfo from tailscale.com/client/web+ - tailscale.com/internal/client/tailscale from tailscale.com/cmd/k8s-operator+ + tailscale.com/internal/client/tailscale from tailscale.com/feature/oauthkey+ tailscale.com/ipn from tailscale.com/client/local+ tailscale.com/ipn/conffile from tailscale.com/ipn/ipnlocal+ đŸ’Ŗ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnlocal+ - tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal+ tailscale.com/ipn/ipnlocal from tailscale.com/ipn/localapi+ tailscale.com/ipn/ipnlocal/netmapcache from tailscale.com/ipn/ipnlocal tailscale.com/ipn/ipnstate from tailscale.com/client/local+ @@ -839,6 +766,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/k8s-operator/sessionrecording/spdy from tailscale.com/k8s-operator/sessionrecording tailscale.com/k8s-operator/sessionrecording/tsrecorder from tailscale.com/k8s-operator/sessionrecording+ tailscale.com/k8s-operator/sessionrecording/ws from tailscale.com/k8s-operator/sessionrecording + tailscale.com/k8s-operator/tsclient from tailscale.com/cmd/k8s-operator+ tailscale.com/kube/egressservices from tailscale.com/cmd/k8s-operator tailscale.com/kube/ingressservices from tailscale.com/cmd/k8s-operator tailscale.com/kube/k8s-proxy/conf from tailscale.com/cmd/k8s-operator @@ -908,12 +836,13 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/tstime from tailscale.com/cmd/k8s-operator+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ tailscale.com/tstime/rate from tailscale.com/wgengine/filter - tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb from tailscale.com/util/eventbus+ tailscale.com/tsweb/varz from tailscale.com/util/usermetric+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/bools from tailscale.com/tsnet+ tailscale.com/types/dnstype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/empty from tailscale.com/ipn+ + tailscale.com/types/events from tailscale.com/control/controlclient+ tailscale.com/types/ipproto from tailscale.com/net/flowtrack+ tailscale.com/types/key from tailscale.com/client/local+ tailscale.com/types/lazy from tailscale.com/ipn/ipnlocal+ @@ -927,12 +856,12 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/types/opt from tailscale.com/client/tailscale+ tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/preftype from tailscale.com/ipn+ - tailscale.com/types/ptr from tailscale.com/cmd/k8s-operator+ tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/control/controlclient+ tailscale.com/types/tkatype from tailscale.com/client/local+ tailscale.com/types/views from tailscale.com/appc+ tailscale.com/util/backoff from tailscale.com/cmd/k8s-operator+ + tailscale.com/util/bufiox from tailscale.com/types/key tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/cmd/k8s-operator+ @@ -997,14 +926,12 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal đŸ’Ŗ tailscale.com/wgengine/wgint from tailscale.com/wgengine+ tailscale.com/wgengine/wglog from tailscale.com/wgengine - tailscale.com/wif from tailscale.com/feature/identityfederation golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ - LD golang.org/x/crypto/blowfish from golang.org/x/crypto/ssh/internal/bcrypt_pbkdf - golang.org/x/crypto/chacha20 from golang.org/x/crypto/ssh+ + golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305 golang.org/x/crypto/chacha20poly1305 from github.com/tailscale/wireguard-go/device+ - golang.org/x/crypto/curve25519 from golang.org/x/crypto/ssh+ + golang.org/x/crypto/curve25519 from github.com/tailscale/wireguard-go/device+ golang.org/x/crypto/hkdf from tailscale.com/control/controlbase golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ @@ -1012,8 +939,6 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/poly1305 from github.com/tailscale/wireguard-go/device golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - LD golang.org/x/crypto/ssh from tailscale.com/ipn/ipnlocal - LD golang.org/x/crypto/ssh/internal/bcrypt_pbkdf from golang.org/x/crypto/ssh golang.org/x/exp/constraints from tailscale.com/tsweb/varz+ golang.org/x/exp/maps from sigs.k8s.io/controller-runtime/pkg/cache+ golang.org/x/exp/slices from tailscale.com/cmd/k8s-operator+ @@ -1023,19 +948,20 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy golang.org/x/net/http2 from k8s.io/apimachinery/pkg/util/net+ golang.org/x/net/http2/hpack from golang.org/x/net/http2+ - golang.org/x/net/icmp from github.com/prometheus-community/pro-bing+ + golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts+ golang.org/x/net/internal/httpcommon from golang.org/x/net/http2 + golang.org/x/net/internal/httpsfv from golang.org/x/net/http2 golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ golang.org/x/net/internal/socks from golang.org/x/net/proxy - golang.org/x/net/ipv4 from github.com/prometheus-community/pro-bing+ - golang.org/x/net/ipv6 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ + golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ golang.org/x/net/proxy from tailscale.com/net/netns D golang.org/x/net/route from tailscale.com/net/netmon+ golang.org/x/net/websocket from tailscale.com/k8s-operator/sessionrecording/ws golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials+ - golang.org/x/oauth2/clientcredentials from tailscale.com/cmd/k8s-operator+ + golang.org/x/oauth2/clientcredentials from tailscale.com/client/tailscale/v2+ golang.org/x/oauth2/internal from golang.org/x/oauth2+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ golang.org/x/sys/cpu from github.com/tailscale/certstore+ @@ -1079,7 +1005,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ crypto/aes from crypto/tls+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ - crypto/dsa from crypto/x509+ + crypto/dsa from crypto/x509 crypto/ecdh from crypto/ecdsa+ crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ @@ -1128,16 +1054,16 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ crypto/internal/randutil from crypto/internal/rand crypto/internal/sysrand from crypto/internal/fips140/drbg crypto/md5 from crypto/tls+ - crypto/mlkem from golang.org/x/crypto/ssh+ + crypto/mlkem from crypto/hpke+ crypto/rand from crypto/ed25519+ - crypto/rc4 from crypto/tls+ + crypto/rc4 from crypto/tls crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ crypto/sha3 from crypto/internal/fips140hash+ crypto/sha512 from crypto/ecdsa+ crypto/subtle from crypto/cipher+ - crypto/tls from github.com/prometheus-community/pro-bing+ + crypto/tls from github.com/prometheus/client_golang/prometheus/promhttp+ crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ D crypto/x509/internal/macos from crypto/x509 @@ -1246,7 +1172,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptrace from github.com/prometheus-community/pro-bing+ + net/http/httptrace from github.com/prometheus/client_golang/prometheus/promhttp+ net/http/httputil from tailscale.com/client/web+ net/http/internal from net/http+ net/http/internal/ascii from net/http+ diff --git a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml index 0c0cb64cbb4ed..feffd03a39cf1 100644 --- a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml @@ -146,3 +146,6 @@ spec: tolerations: {{- toYaml . | nindent 8 }} {{- end }} + {{- with .Values.operatorConfig.priorityClassName }} + priorityClassName: {{ . }} + {{- end }} diff --git a/cmd/k8s-operator/deploy/chart/values.yaml b/cmd/k8s-operator/deploy/chart/values.yaml index 3606f1af3f3d2..185e7c34b1d79 100644 --- a/cmd/k8s-operator/deploy/chart/values.yaml +++ b/cmd/k8s-operator/deploy/chart/values.yaml @@ -72,6 +72,8 @@ operatorConfig: affinity: {} + priorityClassName: "" + podSecurityContext: {} securityContext: {} diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml index a819aa6518684..4d6422ede46ec 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml @@ -104,6 +104,884 @@ spec: description: Pod configuration. type: object properties: + affinity: + description: If specified, applies affinity rules to the pods deployed by the DNSConfig resource. + type: object + properties: + nodeAffinity: + description: Describes node affinity scheduling rules for the pod. + type: object + properties: + preferredDuringSchedulingIgnoredDuringExecution: + description: |- + The scheduler will prefer to schedule pods to nodes that satisfy + the affinity expressions specified by this field, but it may choose + a node that violates one or more of the expressions. The node that is + most preferred is the one with the greatest sum of weights, i.e. + for each node that meets all of the scheduling requirements (resource + request, requiredDuringScheduling affinity expressions, etc.), + compute a sum by iterating through the elements of this field and adding + "weight" to the sum if the node matches the corresponding matchExpressions; the + node(s) with the highest sum are the most preferred. + type: array + items: + description: |- + An empty preferred scheduling term matches all objects with implicit weight 0 + (i.e. it's a no-op). A null preferred scheduling term matches no objects (i.e. is also a no-op). + type: object + required: + - preference + - weight + properties: + preference: + description: A node selector term, associated with the corresponding weight. + type: object + properties: + matchExpressions: + description: A list of node selector requirements by node's labels. + type: array + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements by node's fields. + type: array + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + x-kubernetes-map-type: atomic + weight: + description: Weight associated with matching the corresponding nodeSelectorTerm, in the range 1-100. + type: integer + format: int32 + x-kubernetes-list-type: atomic + requiredDuringSchedulingIgnoredDuringExecution: + description: |- + If the affinity requirements specified by this field are not met at + scheduling time, the pod will not be scheduled onto the node. + If the affinity requirements specified by this field cease to be met + at some point during pod execution (e.g. due to an update), the system + may or may not try to eventually evict the pod from its node. + type: object + required: + - nodeSelectorTerms + properties: + nodeSelectorTerms: + description: Required. A list of node selector terms. The terms are ORed. + type: array + items: + description: |- + A null or empty node selector term matches no objects. The requirements of + them are ANDed. + The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. + type: object + properties: + matchExpressions: + description: A list of node selector requirements by node's labels. + type: array + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements by node's fields. + type: array + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + x-kubernetes-map-type: atomic + x-kubernetes-list-type: atomic + x-kubernetes-map-type: atomic + podAffinity: + description: Describes pod affinity scheduling rules (e.g. co-locate this pod in the same node, zone, etc. as some other pod(s)). + type: object + properties: + preferredDuringSchedulingIgnoredDuringExecution: + description: |- + The scheduler will prefer to schedule pods to nodes that satisfy + the affinity expressions specified by this field, but it may choose + a node that violates one or more of the expressions. The node that is + most preferred is the one with the greatest sum of weights, i.e. + for each node that meets all of the scheduling requirements (resource + request, requiredDuringScheduling affinity expressions, etc.), + compute a sum by iterating through the elements of this field and adding + "weight" to the sum if the node has pods which matches the corresponding podAffinityTerm; the + node(s) with the highest sum are the most preferred. + type: array + items: + description: The weights of all of the matched WeightedPodAffinityTerm fields are added per-node to find the most preferred node(s) + type: object + required: + - podAffinityTerm + - weight + properties: + podAffinityTerm: + description: Required. A pod affinity term, associated with the corresponding weight. + type: object + required: + - topologyKey + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + type: array + items: + type: string + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + weight: + description: |- + weight associated with matching the corresponding podAffinityTerm, + in the range 1-100. + type: integer + format: int32 + x-kubernetes-list-type: atomic + requiredDuringSchedulingIgnoredDuringExecution: + description: |- + If the affinity requirements specified by this field are not met at + scheduling time, the pod will not be scheduled onto the node. + If the affinity requirements specified by this field cease to be met + at some point during pod execution (e.g. due to a pod label update), the + system may or may not try to eventually evict the pod from its node. + When there are multiple elements, the lists of nodes corresponding to each + podAffinityTerm are intersected, i.e. all terms must be satisfied. + type: array + items: + description: |- + Defines a set of pods (namely those matching the labelSelector + relative to the given namespace(s)) that this pod should be + co-located (affinity) or not co-located (anti-affinity) with, + where co-located is defined as running on a node whose value of + the label with key matches that of any node on which + a pod of the set of pods is running + type: object + required: + - topologyKey + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + type: array + items: + type: string + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + x-kubernetes-list-type: atomic + podAntiAffinity: + description: Describes pod anti-affinity scheduling rules (e.g. avoid putting this pod in the same node, zone, etc. as some other pod(s)). + type: object + properties: + preferredDuringSchedulingIgnoredDuringExecution: + description: |- + The scheduler will prefer to schedule pods to nodes that satisfy + the anti-affinity expressions specified by this field, but it may choose + a node that violates one or more of the expressions. The node that is + most preferred is the one with the greatest sum of weights, i.e. + for each node that meets all of the scheduling requirements (resource + request, requiredDuringScheduling anti-affinity expressions, etc.), + compute a sum by iterating through the elements of this field and subtracting + "weight" from the sum if the node has pods which matches the corresponding podAffinityTerm; the + node(s) with the highest sum are the most preferred. + type: array + items: + description: The weights of all of the matched WeightedPodAffinityTerm fields are added per-node to find the most preferred node(s) + type: object + required: + - podAffinityTerm + - weight + properties: + podAffinityTerm: + description: Required. A pod affinity term, associated with the corresponding weight. + type: object + required: + - topologyKey + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + type: array + items: + type: string + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + weight: + description: |- + weight associated with matching the corresponding podAffinityTerm, + in the range 1-100. + type: integer + format: int32 + x-kubernetes-list-type: atomic + requiredDuringSchedulingIgnoredDuringExecution: + description: |- + If the anti-affinity requirements specified by this field are not met at + scheduling time, the pod will not be scheduled onto the node. + If the anti-affinity requirements specified by this field cease to be met + at some point during pod execution (e.g. due to a pod label update), the + system may or may not try to eventually evict the pod from its node. + When there are multiple elements, the lists of nodes corresponding to each + podAffinityTerm are intersected, i.e. all terms must be satisfied. + type: array + items: + description: |- + Defines a set of pods (namely those matching the labelSelector + relative to the given namespace(s)) that this pod should be + co-located (affinity) or not co-located (anti-affinity) with, + where co-located is defined as running on a node whose value of + the label with key matches that of any node on which + a pod of the set of pods is running + type: object + required: + - topologyKey + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + type: array + items: + type: string + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + x-kubernetes-list-type: atomic + nodeSelector: + description: If specified, applies node selector rules to the pods deployed by the DNSConfig resource. + type: object + additionalProperties: + type: string tolerations: description: If specified, applies tolerations to the pods deployed by the DNSConfig resource. type: array diff --git a/cmd/k8s-operator/deploy/manifests/operator.yaml b/cmd/k8s-operator/deploy/manifests/operator.yaml index 597641bdefecf..07c9f3af3d307 100644 --- a/cmd/k8s-operator/deploy/manifests/operator.yaml +++ b/cmd/k8s-operator/deploy/manifests/operator.yaml @@ -442,6 +442,884 @@ spec: pod: description: Pod configuration. properties: + affinity: + description: If specified, applies affinity rules to the pods deployed by the DNSConfig resource. + properties: + nodeAffinity: + description: Describes node affinity scheduling rules for the pod. + properties: + preferredDuringSchedulingIgnoredDuringExecution: + description: |- + The scheduler will prefer to schedule pods to nodes that satisfy + the affinity expressions specified by this field, but it may choose + a node that violates one or more of the expressions. The node that is + most preferred is the one with the greatest sum of weights, i.e. + for each node that meets all of the scheduling requirements (resource + request, requiredDuringScheduling affinity expressions, etc.), + compute a sum by iterating through the elements of this field and adding + "weight" to the sum if the node matches the corresponding matchExpressions; the + node(s) with the highest sum are the most preferred. + items: + description: |- + An empty preferred scheduling term matches all objects with implicit weight 0 + (i.e. it's a no-op). A null preferred scheduling term matches no objects (i.e. is also a no-op). + properties: + preference: + description: A node selector term, associated with the corresponding weight. + properties: + matchExpressions: + description: A list of node selector requirements by node's labels. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements by node's fields. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + type: object + x-kubernetes-map-type: atomic + weight: + description: Weight associated with matching the corresponding nodeSelectorTerm, in the range 1-100. + format: int32 + type: integer + required: + - preference + - weight + type: object + type: array + x-kubernetes-list-type: atomic + requiredDuringSchedulingIgnoredDuringExecution: + description: |- + If the affinity requirements specified by this field are not met at + scheduling time, the pod will not be scheduled onto the node. + If the affinity requirements specified by this field cease to be met + at some point during pod execution (e.g. due to an update), the system + may or may not try to eventually evict the pod from its node. + properties: + nodeSelectorTerms: + description: Required. A list of node selector terms. The terms are ORed. + items: + description: |- + A null or empty node selector term matches no objects. The requirements of + them are ANDed. + The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. + properties: + matchExpressions: + description: A list of node selector requirements by node's labels. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements by node's fields. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + type: object + x-kubernetes-map-type: atomic + type: array + x-kubernetes-list-type: atomic + required: + - nodeSelectorTerms + type: object + x-kubernetes-map-type: atomic + type: object + podAffinity: + description: Describes pod affinity scheduling rules (e.g. co-locate this pod in the same node, zone, etc. as some other pod(s)). + properties: + preferredDuringSchedulingIgnoredDuringExecution: + description: |- + The scheduler will prefer to schedule pods to nodes that satisfy + the affinity expressions specified by this field, but it may choose + a node that violates one or more of the expressions. The node that is + most preferred is the one with the greatest sum of weights, i.e. + for each node that meets all of the scheduling requirements (resource + request, requiredDuringScheduling affinity expressions, etc.), + compute a sum by iterating through the elements of this field and adding + "weight" to the sum if the node has pods which matches the corresponding podAffinityTerm; the + node(s) with the highest sum are the most preferred. + items: + description: The weights of all of the matched WeightedPodAffinityTerm fields are added per-node to find the most preferred node(s) + properties: + podAffinityTerm: + description: Required. A pod affinity term, associated with the corresponding weight. + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + items: + type: string + type: array + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + required: + - topologyKey + type: object + weight: + description: |- + weight associated with matching the corresponding podAffinityTerm, + in the range 1-100. + format: int32 + type: integer + required: + - podAffinityTerm + - weight + type: object + type: array + x-kubernetes-list-type: atomic + requiredDuringSchedulingIgnoredDuringExecution: + description: |- + If the affinity requirements specified by this field are not met at + scheduling time, the pod will not be scheduled onto the node. + If the affinity requirements specified by this field cease to be met + at some point during pod execution (e.g. due to a pod label update), the + system may or may not try to eventually evict the pod from its node. + When there are multiple elements, the lists of nodes corresponding to each + podAffinityTerm are intersected, i.e. all terms must be satisfied. + items: + description: |- + Defines a set of pods (namely those matching the labelSelector + relative to the given namespace(s)) that this pod should be + co-located (affinity) or not co-located (anti-affinity) with, + where co-located is defined as running on a node whose value of + the label with key matches that of any node on which + a pod of the set of pods is running + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + items: + type: string + type: array + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + required: + - topologyKey + type: object + type: array + x-kubernetes-list-type: atomic + type: object + podAntiAffinity: + description: Describes pod anti-affinity scheduling rules (e.g. avoid putting this pod in the same node, zone, etc. as some other pod(s)). + properties: + preferredDuringSchedulingIgnoredDuringExecution: + description: |- + The scheduler will prefer to schedule pods to nodes that satisfy + the anti-affinity expressions specified by this field, but it may choose + a node that violates one or more of the expressions. The node that is + most preferred is the one with the greatest sum of weights, i.e. + for each node that meets all of the scheduling requirements (resource + request, requiredDuringScheduling anti-affinity expressions, etc.), + compute a sum by iterating through the elements of this field and subtracting + "weight" from the sum if the node has pods which matches the corresponding podAffinityTerm; the + node(s) with the highest sum are the most preferred. + items: + description: The weights of all of the matched WeightedPodAffinityTerm fields are added per-node to find the most preferred node(s) + properties: + podAffinityTerm: + description: Required. A pod affinity term, associated with the corresponding weight. + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + items: + type: string + type: array + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + required: + - topologyKey + type: object + weight: + description: |- + weight associated with matching the corresponding podAffinityTerm, + in the range 1-100. + format: int32 + type: integer + required: + - podAffinityTerm + - weight + type: object + type: array + x-kubernetes-list-type: atomic + requiredDuringSchedulingIgnoredDuringExecution: + description: |- + If the anti-affinity requirements specified by this field are not met at + scheduling time, the pod will not be scheduled onto the node. + If the anti-affinity requirements specified by this field cease to be met + at some point during pod execution (e.g. due to a pod label update), the + system may or may not try to eventually evict the pod from its node. + When there are multiple elements, the lists of nodes corresponding to each + podAffinityTerm are intersected, i.e. all terms must be satisfied. + items: + description: |- + Defines a set of pods (namely those matching the labelSelector + relative to the given namespace(s)) that this pod should be + co-located (affinity) or not co-located (anti-affinity) with, + where co-located is defined as running on a node whose value of + the label with key matches that of any node on which + a pod of the set of pods is running + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + items: + type: string + type: array + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + required: + - topologyKey + type: object + type: array + x-kubernetes-list-type: atomic + type: object + type: object + nodeSelector: + additionalProperties: + type: string + description: If specified, applies node selector rules to the pods deployed by the DNSConfig resource. + type: object tolerations: description: If specified, applies tolerations to the pods deployed by the DNSConfig resource. items: diff --git a/cmd/k8s-operator/dnsrecords_test.go b/cmd/k8s-operator/dnsrecords_test.go index 0d89c4a863e4d..c6c5ee0296ca3 100644 --- a/cmd/k8s-operator/dnsrecords_test.go +++ b/cmd/k8s-operator/dnsrecords_test.go @@ -25,7 +25,6 @@ import ( tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" "tailscale.com/tstest" - "tailscale.com/types/ptr" ) func TestDNSRecordsReconciler(t *testing.T) { @@ -44,7 +43,7 @@ func TestDNSRecordsReconciler(t *testing.T) { Namespace: "test", }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), }, Status: networkingv1.IngressStatus{ LoadBalancer: networkingv1.IngressLoadBalancerStatus{ @@ -150,7 +149,7 @@ func TestDNSRecordsReconciler(t *testing.T) { // 7. A not-ready Endpoint is removed from DNS config. mustUpdate(t, fc, ep.Namespace, ep.Name, func(ep *discoveryv1.EndpointSlice) { - ep.Endpoints[0].Conditions.Ready = ptr.To(false) + ep.Endpoints[0].Conditions.Ready = new(false) ep.Endpoints = append(ep.Endpoints, discoveryv1.Endpoint{ Addresses: []string{"1.2.3.4"}, }) @@ -220,13 +219,13 @@ func TestDNSRecordsReconciler(t *testing.T) { Endpoints: []discoveryv1.Endpoint{{ Addresses: []string{"10.1.0.100", "10.1.0.101", "10.1.0.102"}, // Pod IPs that should NOT be used Conditions: discoveryv1.EndpointConditions{ - Ready: ptr.To(true), - Serving: ptr.To(true), - Terminating: ptr.To(false), + Ready: new(true), + Serving: new(true), + Terminating: new(false), }, }}, Ports: []discoveryv1.EndpointPort{{ - Port: ptr.To(int32(10443)), + Port: new(int32(10443)), }}, } @@ -316,7 +315,7 @@ func TestDNSRecordsReconcilerDualStack(t *testing.T) { Namespace: "test", }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), }, Status: networkingv1.IngressStatus{ LoadBalancer: networkingv1.IngressLoadBalancerStatus{ @@ -447,9 +446,9 @@ func endpointSliceForService(svc *corev1.Service, ip string, fam discoveryv1.Add Endpoints: []discoveryv1.Endpoint{{ Addresses: []string{ip}, Conditions: discoveryv1.EndpointConditions{ - Ready: ptr.To(true), - Serving: ptr.To(true), - Terminating: ptr.To(false), + Ready: new(true), + Serving: new(true), + Terminating: new(false), }, }}, } diff --git a/cmd/k8s-operator/e2e/helpers.go b/cmd/k8s-operator/e2e/helpers.go new file mode 100644 index 0000000000000..e01821c2367e3 --- /dev/null +++ b/cmd/k8s-operator/e2e/helpers.go @@ -0,0 +1,31 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package e2e + +import ( + "crypto/rand" + "crypto/tls" + "fmt" + "net/http" + "strings" + "time" + + "tailscale.com/tsnet" +) + +func generateName(prefix string) string { + return fmt.Sprintf("%s-%s", prefix, strings.ToLower(rand.Text())) +} + +// newHTTPClient returns a HTTP client for the given tailnet client. +// When running against devcontrol, trusts Pebble testCAs. +func newHTTPClient(cl *tsnet.Server) *http.Client { + return &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{RootCAs: testCAs}, + DialContext: cl.Dial, + }, + } +} diff --git a/cmd/k8s-operator/e2e/ingress_test.go b/cmd/k8s-operator/e2e/ingress_test.go index 47a838414d449..bef24ca5a0a3b 100644 --- a/cmd/k8s-operator/e2e/ingress_test.go +++ b/cmd/k8s-operator/e2e/ingress_test.go @@ -5,80 +5,48 @@ package e2e import ( "context" - "encoding/json" "fmt" "net/http" + "strings" "testing" "time" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/util/wait" - "k8s.io/client-go/kubernetes" "sigs.k8s.io/controller-runtime/pkg/client" - "sigs.k8s.io/yaml" - "tailscale.com/cmd/testwrapper/flakytest" + "tailscale.com/client/tailscale/v2" kube "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" + "tailscale.com/tsnet" "tailscale.com/tstest" - "tailscale.com/types/ptr" "tailscale.com/util/httpm" ) // See [TestMain] for test requirements. -func TestIngress(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/corp/issues/37533") +func TestL3Ingress(t *testing.T) { if tnClient == nil { - t.Skip("TestIngress requires a working tailnet client") + t.Skip("TestL3Ingress requires a working tailnet client") } // Apply nginx - createAndCleanup(t, kubeClient, - &appsv1.Deployment{ - ObjectMeta: metav1.ObjectMeta{ - Name: "nginx", - Namespace: "default", - Labels: map[string]string{ - "app.kubernetes.io/name": "nginx", - }, - }, - Spec: appsv1.DeploymentSpec{ - Replicas: ptr.To[int32](1), - Selector: &metav1.LabelSelector{ - MatchLabels: map[string]string{ - "app.kubernetes.io/name": "nginx", - }, - }, - Template: corev1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Labels: map[string]string{ - "app.kubernetes.io/name": "nginx", - }, - }, - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "nginx", - Image: "nginx", - }, - }, - }, - }, - }, - }) + nginx := nginxDeployment(ns) + createAndCleanup(t, kubeClient, nginx) // Apply service to expose it as ingress svc := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ - Name: "test-ingress", - Namespace: "default", + Name: generateName("test-ingress"), + Namespace: ns, Annotations: map[string]string{ "tailscale.com/expose": "true", }, }, Spec: corev1.ServiceSpec{ Selector: map[string]string{ - "app.kubernetes.io/name": "nginx", + "app.kubernetes.io/name": nginx.Name, }, Ports: []corev1.ServicePort{ { @@ -91,100 +59,513 @@ func TestIngress(t *testing.T) { } createAndCleanup(t, kubeClient, svc) - // TODO(tomhjp): Delete once we've reproduced the flake with this extra info. - t0 := time.Now() - watcherCtx, cancelWatcher := context.WithCancel(t.Context()) - defer cancelWatcher() - go func() { - // client-go client for logs. - clientGoKubeClient, err := kubernetes.NewForConfig(restCfg) - if err != nil { - t.Logf("error creating client-go Kubernetes client: %v", err) - return + if err := tstest.WaitFor(time.Minute, func() error { + maybeReadySvc := &corev1.Service{ObjectMeta: objectMeta(ns, svc.Name)} + if err := get(t.Context(), kubeClient, maybeReadySvc); err != nil { + return err } + isReady := kube.SvcIsReady(maybeReadySvc) + if isReady { + t.Log("Service is ready") + return nil + } + return fmt.Errorf("Service is not ready yet") + }); err != nil { + t.Fatalf("error waiting for the Service to become Ready: %v", err) + } - for { - select { - case <-watcherCtx.Done(): - t.Logf("stopping watcher after %v", time.Since(t0)) - return - case <-time.After(time.Minute): - t.Logf("dumping info after %v elapsed", time.Since(t0)) - // Service itself. - svc := &corev1.Service{ObjectMeta: objectMeta("default", "test-ingress")} - err := get(watcherCtx, kubeClient, svc) - svcYaml, _ := yaml.Marshal(svc) - t.Logf("Service: %s, error: %v\n%s", svc.Name, err, string(svcYaml)) - - // Pods in tailscale namespace. - var pods corev1.PodList - if err := kubeClient.List(watcherCtx, &pods, client.InNamespace("tailscale")); err != nil { - t.Logf("error listing Pods in tailscale namespace: %v", err) - } else { - t.Logf("%d Pods", len(pods.Items)) - for _, pod := range pods.Items { - podYaml, _ := yaml.Marshal(pod) - t.Logf("Pod: %s\n%s", pod.Name, string(podYaml)) - logs := clientGoKubeClient.CoreV1().Pods("tailscale").GetLogs(pod.Name, &corev1.PodLogOptions{}).Do(watcherCtx) - logData, err := logs.Raw() - if err != nil { - t.Logf("error reading logs for Pod %s: %v", pod.Name, err) - continue - } - t.Logf("Logs for Pod %s:\n%s", pod.Name, string(logData)) - } - } + // Get the DNS name for the Service from the associated Secret. + var fqdn string + if err := tstest.WaitFor(time.Minute, func() error { + var secrets corev1.SecretList + if err := kubeClient.List(t.Context(), &secrets, + client.InNamespace("tailscale"), + client.MatchingLabels{ + "tailscale.com/parent-resource": svc.Name, + "tailscale.com/parent-resource-ns": ns, + }, + ); err != nil { + return err + } + if len(secrets.Items) == 0 { + return fmt.Errorf("Service not ready yet") + } + fqdn = strings.TrimSuffix(string(secrets.Items[0].Data[kubetypes.KeyDeviceFQDN]), ".") + if fqdn != "" { + t.Log("Got DNS name for Service") + return nil + } + return fmt.Errorf("device FQDN not set yet") + }); err != nil { + t.Fatalf("error waiting for DNS Name for Service: %v", err) + } + + if err := testIngressIsReachable(t, newHTTPClient(tnClient), fmt.Sprintf("http://%s:80", fqdn)); err != nil { + t.Fatal(err) + } +} + +func TestL3HAIngress(t *testing.T) { + if tnClient == nil { + t.Skip("TestL3HAIngress requires a working tailnet client") + } + + // Apply nginx. + nginx := nginxDeployment(ns) + createAndCleanup(t, kubeClient, nginx) + + // Create an ingress ProxyGroup. + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: generateName("ingress"), + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + }, + } + createAndCleanup(t, kubeClient, pg) + + // Apply a Service to expose nginx via the ProxyGroup. + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: generateName("test-ingress"), + Namespace: ns, + Annotations: map[string]string{ + "tailscale.com/proxy-group": pg.Name, + }, + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: new("tailscale"), + Selector: map[string]string{ + "app.kubernetes.io/name": nginx.Name, + }, + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: "TCP", + Port: 80, + }, + }, + }, + } + createAndCleanup(t, kubeClient, svc) + + var svcIPv4 string + forceReconcile := triggerReconcile(t, + client.ObjectKey{Namespace: ns, Name: svc.Name}, + &corev1.Service{}, 30*time.Second) - // Tailscale status on the tailnet. - lc, err := tnClient.LocalClient() - if err != nil { - t.Logf("error getting tailnet local client: %v", err) - } else { - status, err := lc.Status(watcherCtx) - statusJSON, _ := json.MarshalIndent(status, "", " ") - t.Logf("Tailnet status: %s, error: %v", string(statusJSON), err) + // Wait for Service to be ready + if err := tstest.WaitFor(5*time.Minute, func() error { + maybeReadySvc := &corev1.Service{ObjectMeta: objectMeta(ns, svc.Name)} + forceReconcile() + if err := get(t.Context(), kubeClient, maybeReadySvc); err != nil { + return err + } + for _, cond := range maybeReadySvc.Status.Conditions { + if cond.Type == string(tsapi.IngressSvcConfigured) && cond.Status == metav1.ConditionTrue { + if len(maybeReadySvc.Status.LoadBalancer.Ingress) == 0 { + return fmt.Errorf("Service does not have an IP assigned yet") } + svcIPv4 = maybeReadySvc.Status.LoadBalancer.Ingress[0].IP + t.Log("Service is ready") + return nil } } - }() + return fmt.Errorf("Service is not ready yet") + }); err != nil { + t.Fatalf("error waiting for the Service to become ready: %v", err) + } + + if err := testIngressIsReachable(t, newHTTPClient(tnClient), fmt.Sprintf("http://%s:80", svcIPv4)); err != nil { + t.Fatal(err) + } +} + +func TestL7Ingress(t *testing.T) { + if tnClient == nil { + t.Skip("TestL7Ingress requires a working tailnet client") + } + + // Apply nginx Deployment and Service. + nginx := nginxDeployment(ns) + createAndCleanup(t, kubeClient, nginx) + createAndCleanup(t, kubeClient, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: nginx.Name, + Namespace: ns, + }, + Spec: corev1.ServiceSpec{ + Selector: map[string]string{ + "app.kubernetes.io/name": nginx.Name, + }, + Ports: []corev1.ServicePort{ + { + Name: "http", + Port: 80, + }, + }, + }, + }) + + // Apply Ingress to expose nginx. + ingress := l7Ingress(ns, nginx.Name, map[string]string{}) + createAndCleanup(t, kubeClient, ingress) + + t.Log("Waiting for the Ingress to be ready...") + + hostname, err := waitForIngressHostname(t, ns, ingress.Name) + if err != nil { + t.Fatalf("error waiting for Ingress hostname: %v", err) + } + + if err := testIngressIsReachable(t, newHTTPClient(tnClient), fmt.Sprintf("https://%s:443", hostname)); err != nil { + t.Fatal(err) + } +} + +func TestL7HAIngress(t *testing.T) { + if tnClient == nil { + t.Skip("TestL7HAIngress requires a working tailnet client") + } + + // Apply nginx Deployment and Service. + nginx := nginxDeployment(ns) + createAndCleanup(t, kubeClient, nginx) + createAndCleanup(t, kubeClient, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: nginx.Name, + Namespace: ns, + }, + Spec: corev1.ServiceSpec{ + Selector: map[string]string{ + "app.kubernetes.io/name": nginx.Name, + }, + Ports: []corev1.ServicePort{ + { + Name: "http", + Port: 80, + }, + }, + }, + }) + + // Create ProxyGroup that the Ingress will reference. + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: generateName("ingress"), + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + }, + } + createAndCleanup(t, kubeClient, pg) + + // Apply Ingress to expose nginx. + ingress := l7Ingress(ns, nginx.Name, map[string]string{"tailscale.com/proxy-group": pg.Name}) + createAndCleanup(t, kubeClient, ingress) + + t.Log("Waiting for the Ingress to be ready...") + + hostname, err := waitForIngressHostname(t, ns, ingress.Name) + if err != nil { + t.Fatalf("error waiting for Ingress hostname: %v", err) + } + + if err := testIngressIsReachable(t, newHTTPClient(tnClient), fmt.Sprintf("https://%s:443", hostname)); err != nil { + t.Fatal(err) + } +} + +func TestL7HAIngressMultiTailnet(t *testing.T) { + if tnClient == nil || secondTNClient == nil { + t.Skip("TestL7HAIngressMultiTailnet requires a working tailnet client for a first and second tailnet") + } + + // Apply nginx Deployment and Service. + nginx := nginxDeployment(ns) + createAndCleanup(t, kubeClient, nginx) + createAndCleanup(t, kubeClient, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: nginx.Name, + Namespace: ns, + }, + Spec: corev1.ServiceSpec{ + Selector: map[string]string{ + "app.kubernetes.io/name": nginx.Name, + }, + Ports: []corev1.ServicePort{ + { + Name: "http", + Port: 80, + }, + }, + }, + }) + + // Create Ingress ProxyGroup for each Tailnet. + firstTailnetPG := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: generateName("first-tailnet"), + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + }, + } + createAndCleanup(t, kubeClient, firstTailnetPG) + secondTailnetPG := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: generateName("second-tailnet"), + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + Tailnet: "second-tailnet", + }, + } + createAndCleanup(t, kubeClient, secondTailnetPG) + + if err := verifyProxyGroupTailnet(t, firstTailnetPG, tnClient); err != nil { + t.Fatalf("verifying ProxyGroup %s is registered to the correct tailnet: %v", firstTailnetPG.Name, err) + } + if err := verifyProxyGroupTailnet(t, secondTailnetPG, secondTNClient); err != nil { + t.Fatalf("verifying ProxyGroup %s is registered to the correct tailnet: %v", secondTailnetPG.Name, err) + } + + // Apply Ingress to expose nginx. + ingress := l7Ingress(ns, nginx.Name, map[string]string{ + "tailscale.com/proxy-group": secondTailnetPG.Name, + }) + createAndCleanup(t, kubeClient, ingress) + + // Check that the tailscale (VIP) Service has been created in the expected Tailnet. + svcName := "svc:" + ingress.Name + if err := tstest.WaitFor(3*time.Minute, func() error { + _, err := secondTSClient.VIPServices().Get(t.Context(), svcName) + if tailscale.IsNotFound(err) { + return fmt.Errorf("Tailscale service %q not yet in expected tailnet", svcName) + } + return err + }); err != nil { + t.Fatalf("Tailscale service %q never appeared in expected tailnet: %v", svcName, err) + } + hostname, err := waitForIngressHostname(t, ns, ingress.Name) + if err != nil { + t.Fatalf("error waiting for Ingress hostname: %v", err) + } + if err := testIngressIsReachable(t, newHTTPClient(secondTNClient), fmt.Sprintf("https://%s:443", hostname)); err != nil { + t.Fatal(err) + } +} + +func l7Ingress(namespace, svc string, annotations map[string]string) *networkingv1.Ingress { + name := generateName("test-ingress") + ingress := &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + Annotations: annotations, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: new("tailscale"), + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{name}}, + }, + Rules: []networkingv1.IngressRule{ + { + IngressRuleValue: networkingv1.IngressRuleValue{ + HTTP: &networkingv1.HTTPIngressRuleValue{ + Paths: []networkingv1.HTTPIngressPath{ + { + Path: "/", + PathType: new(networkingv1.PathTypePrefix), + Backend: networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: svc, + Port: networkingv1.ServiceBackendPort{ + Number: 80, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + return ingress +} - // TODO: instead of timing out only when test times out, cancel context after 60s or so. - if err := wait.PollUntilContextCancel(t.Context(), time.Millisecond*100, true, func(ctx context.Context) (done bool, err error) { - if time.Since(t0) > time.Minute { - t.Logf("%v elapsed waiting for Service default/test-ingress to become Ready", time.Since(t0)) +func nginxDeployment(namespace string) *appsv1.Deployment { + name := generateName("nginx") + return &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + Labels: map[string]string{ + "app.kubernetes.io/name": name, + }, + }, + Spec: appsv1.DeploymentSpec{ + Replicas: new(int32(1)), + Selector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "app.kubernetes.io/name": name, + }, + }, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + "app.kubernetes.io/name": name, + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "nginx", + Image: "nginx", + }, + }, + }, + }, + }, + } +} + +// triggerReconcile triggers an expected reconcile for the given object if +// none occurs. This is needed when running some tests against devcontrol, +// where the final change that should trigger a reconcile does not always do so. +// This has not been reproducible in a real tailnet environment, so a +// workaround that runs only when using devcontrol is acceptable. +func triggerReconcile(t testing.TB, key client.ObjectKey, obj client.Object, after time.Duration) func() { + if !*fDevcontrol { + return func() {} + } + triggerAt := time.Now().Add(after) + var triggered bool + return func() { + if triggered || !time.Now().After(triggerAt) { + return } - maybeReadySvc := &corev1.Service{ObjectMeta: objectMeta("default", "test-ingress")} - if err := get(ctx, kubeClient, maybeReadySvc); err != nil { - return false, err + if err := kubeClient.Get(t.Context(), key, obj); err != nil { + t.Logf("failed to get %s: %v", key, err) + return } - isReady := kube.SvcIsReady(maybeReadySvc) - if isReady { - t.Log("Service is ready") + ann := obj.GetAnnotations() + if ann == nil { + ann = map[string]string{} } - return isReady, nil - }); err != nil { - t.Fatalf("error waiting for the Service to become Ready: %v", err) + ann["tailscale.com/trigger-reconcile"] = "true" + obj.SetAnnotations(ann) + if err := kubeClient.Update(t.Context(), obj); err != nil { + t.Logf("failed to update %s: %v", key, err) + return + } + triggered = true } - cancelWatcher() +} +func testIngressIsReachable(t *testing.T, httpClient *http.Client, url string) error { + t.Helper() var resp *http.Response if err := tstest.WaitFor(time.Minute, func() error { - // TODO(tomhjp): Get the tailnet DNS name from the associated secret instead. - // If we are not the first tailnet node with the requested name, we'll get - // a -N suffix. - req, err := http.NewRequest(httpm.GET, fmt.Sprintf("http://%s-%s:80", svc.Namespace, svc.Name), nil) + req, err := http.NewRequest(httpm.GET, url, nil) if err != nil { return err } - ctx, cancel := context.WithTimeout(t.Context(), time.Second) + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) defer cancel() - resp, err = tnClient.HTTPClient().Do(req.WithContext(ctx)) + resp, err = httpClient.Do(req.WithContext(ctx)) + if err != nil { + return err + } + resp.Body.Close() + return nil + }); err != nil { + return fmt.Errorf("error trying to reach %s: %w", url, err) + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status from %s: %d", url, resp.StatusCode) + } + return nil +} + +// verifyProxyGroupTailnet verifies that a ProxyGroup is registered to the correct tailnet. +// This is done by getting the expected tailnet domain for the tailnet client, +// and comparing this with the actual device fqdn in the ProxyGroup state secret. +func verifyProxyGroupTailnet(t *testing.T, pg *tsapi.ProxyGroup, cl *tsnet.Server) error { + t.Helper() + // Determine the expected tailnet Magic DNS Name. + lc, err := cl.LocalClient() + if err != nil { return err + } + status, err := lc.Status(t.Context()) + if err != nil { + return err + } + _, expectedTailnet, ok := strings.Cut(strings.TrimSuffix(status.Self.DNSName, "."), ".") + if !ok { + return fmt.Errorf("unexpected DNSName format %q", status.Self.DNSName) + } + // Read the device FQDN from the first state secret for the ProxyGroup, + // and verify that this matches the expected tailnet. + if err := tstest.WaitFor(3*time.Minute, func() error { + var secrets corev1.SecretList + if err := kubeClient.List(t.Context(), &secrets, + client.InNamespace("tailscale"), + client.MatchingLabels{ + kubetypes.LabelSecretType: kubetypes.LabelSecretTypeState, + "tailscale.com/parent-resource-type": "proxygroup", + "tailscale.com/parent-resource": pg.Name, + }, + ); err != nil { + return err + } + if len(secrets.Items) == 0 { + return fmt.Errorf("no state secrets found for ProxyGroup %q yet", pg.Name) + } + fqdn := strings.TrimSuffix(string(secrets.Items[0].Data[kubetypes.KeyDeviceFQDN]), ".") + _, tailnet, ok := strings.Cut(fqdn, ".") + if !ok { + return fmt.Errorf("ProxyGroup %q: device FQDN %q has no domain yet", pg.Name, fqdn) + } + if tailnet != expectedTailnet { + return fmt.Errorf("ProxyGroup %q on wrong tailnet: got domain %q, want %q", pg.Name, tailnet, expectedTailnet) + } + return nil }); err != nil { - t.Fatalf("error trying to reach Service: %v", err) + return fmt.Errorf("ProxyGroup %q not on expected tailnet: %v", pg.Name, err) } + return nil +} - if resp.StatusCode != http.StatusOK { - t.Fatalf("unexpected status: %v; response body s", resp.StatusCode) +func waitForIngressHostname(t *testing.T, namespace, name string) (string, error) { + t.Helper() + var hostname string + forceReconcile := triggerReconcile(t, + client.ObjectKey{Namespace: namespace, Name: name}, + &networkingv1.Ingress{}, 30*time.Second) + + if err := tstest.WaitFor(5*time.Minute, func() error { + forceReconcile() + ing := &networkingv1.Ingress{} + if err := kubeClient.Get(t.Context(), client.ObjectKey{ + Namespace: namespace, Name: name, + }, ing); err != nil { + return err + } + if len(ing.Status.LoadBalancer.Ingress) == 0 || + ing.Status.LoadBalancer.Ingress[0].Hostname == "" { + return fmt.Errorf("Ingress not ready yet") + } + hostname = ing.Status.LoadBalancer.Ingress[0].Hostname + t.Log("Ingress is ready") + return nil + }); err != nil { + return "", fmt.Errorf("Ingress %s/%s never got a hostname: %w", namespace, name, err) } + return hostname, nil } diff --git a/cmd/k8s-operator/e2e/main_test.go b/cmd/k8s-operator/e2e/main_test.go index 02f614014dbee..9eab9e30157aa 100644 --- a/cmd/k8s-operator/e2e/main_test.go +++ b/cmd/k8s-operator/e2e/main_test.go @@ -54,7 +54,7 @@ func createAndCleanup(t *testing.T, cl client.Client, obj client.Object) { t.Cleanup(func() { // Use context.Background() for cleanup, as t.Context() is cancelled // just before cleanup functions are called. - if err = cl.Delete(context.Background(), obj); err != nil { + if err := cl.Delete(context.Background(), obj); err != nil { t.Errorf("error cleaning up %s %s/%s: %s", obj.GetObjectKind().GroupVersionKind(), obj.GetNamespace(), obj.GetName(), err) } }) @@ -69,7 +69,7 @@ func createAndCleanupErr(t *testing.T, cl client.Client, obj client.Object) erro } t.Cleanup(func() { - if err = cl.Delete(context.Background(), obj); err != nil { + if err := cl.Delete(context.Background(), obj); err != nil { t.Errorf("error cleaning up %s %s/%s: %s", obj.GetObjectKind().GroupVersionKind(), obj.GetNamespace(), obj.GetName(), err) } }) diff --git a/cmd/k8s-operator/e2e/pebble.go b/cmd/k8s-operator/e2e/pebble.go index 5fcb35e057c3d..7abe3416ef7dc 100644 --- a/cmd/k8s-operator/e2e/pebble.go +++ b/cmd/k8s-operator/e2e/pebble.go @@ -12,8 +12,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/intstr" "sigs.k8s.io/controller-runtime/pkg/client" - - "tailscale.com/types/ptr" ) func applyPebbleResources(ctx context.Context, cl client.Client) error { @@ -46,7 +44,7 @@ func pebbleDeployment(tag string) *appsv1.Deployment { Namespace: ns, }, Spec: appsv1.DeploymentSpec{ - Replicas: ptr.To[int32](1), + Replicas: new(int32(1)), Selector: &metav1.LabelSelector{ MatchLabels: map[string]string{ "app": "pebble", diff --git a/cmd/k8s-operator/e2e/proxy_test.go b/cmd/k8s-operator/e2e/proxy_test.go index 2d4fa53cc2589..3caf1c91d8bc9 100644 --- a/cmd/k8s-operator/e2e/proxy_test.go +++ b/cmd/k8s-operator/e2e/proxy_test.go @@ -4,10 +4,8 @@ package e2e import ( - "crypto/tls" "encoding/json" "fmt" - "net/http" "testing" "time" @@ -61,15 +59,7 @@ func TestProxy(t *testing.T) { Host: fmt.Sprintf("https://%s:443", hostNameFromOperatorSecret(t, operatorSecret)), } proxyCl, err := client.New(proxyCfg, client.Options{ - HTTPClient: &http.Client{ - Timeout: 10 * time.Second, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - RootCAs: testCAs, - }, - DialContext: tnClient.Dial, - }, - }, + HTTPClient: newHTTPClient(tnClient), }) if err != nil { t.Fatal(err) diff --git a/cmd/k8s-operator/e2e/proxygrouppolicy_test.go b/cmd/k8s-operator/e2e/proxygrouppolicy_test.go index f8126499b0db0..0e73394d539da 100644 --- a/cmd/k8s-operator/e2e/proxygrouppolicy_test.go +++ b/cmd/k8s-operator/e2e/proxygrouppolicy_test.go @@ -13,7 +13,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" - "tailscale.com/types/ptr" ) // See [TestMain] for test requirements. @@ -82,7 +81,7 @@ func TestProxyGroupPolicy(t *testing.T) { }, Spec: corev1.ServiceSpec{ Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), Ports: []corev1.ServicePort{ { Port: 8080, @@ -112,7 +111,7 @@ func TestProxyGroupPolicy(t *testing.T) { }, }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), DefaultBackend: &networkingv1.IngressBackend{ Service: &networkingv1.IngressServiceBackend{ Name: "nginx", diff --git a/cmd/k8s-operator/e2e/setup.go b/cmd/k8s-operator/e2e/setup.go index c4fd45d3e4125..0d4ca80ad68f9 100644 --- a/cmd/k8s-operator/e2e/setup.go +++ b/cmd/k8s-operator/e2e/setup.go @@ -40,6 +40,7 @@ import ( "helm.sh/helm/v3/pkg/release" "helm.sh/helm/v3/pkg/storage/driver" corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/rest" @@ -53,11 +54,13 @@ import ( "sigs.k8s.io/kind/pkg/cluster/nodeutils" "sigs.k8s.io/kind/pkg/cmd" - "tailscale.com/internal/client/tailscale" + "tailscale.com/client/tailscale/v2" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" + tsoperator "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/tsnet" + "tailscale.com/util/must" ) const ( @@ -68,10 +71,13 @@ const ( ) var ( - tsClient *tailscale.Client // For API calls to control. - tnClient *tsnet.Server // For testing real tailnet traffic. - restCfg *rest.Config // For constructing a client-go client if necessary. - kubeClient client.WithWatch // For k8s API calls. + tsClient *tailscale.Client // For API calls to control. + tnClient *tsnet.Server // For testing real tailnet traffic on first tailnet. + secondTSClient *tailscale.Client // For API calls to the secondary tailnet (_second_tailnet). + secondTNClient *tsnet.Server // For testing real tailnet traffic on second tailnet. + restCfg *rest.Config // For constructing a client-go client if necessary. + kubeClient client.WithWatch // For k8s API calls. + clusterLoginServer string //go:embed certs/pebble.minica.crt pebbleMiniCACert []byte @@ -105,7 +111,8 @@ func runTests(m *testing.M) (int, error) { if err != nil { return 0, err } - if err := os.MkdirAll(tmp, 0755); err != nil { + + if err = os.MkdirAll(tmp, 0755); err != nil { return 0, fmt.Errorf("failed to create temp dir: %w", err) } @@ -121,15 +128,17 @@ func runTests(m *testing.M) (int, error) { kindProvider = cluster.NewProvider( cluster.ProviderWithLogger(cmd.NewLogger()), ) + clusters, err := kindProvider.List() if err != nil { return 0, fmt.Errorf("failed to list kind clusters: %w", err) } + if !slices.Contains(clusters, kindClusterName) { if err := kindProvider.Create(kindClusterName, cluster.CreateWithWaitForReady(5*time.Minute), cluster.CreateWithKubeconfigPath(kubeconfig), - cluster.CreateWithNodeImage("kindest/node:v1.30.0"), + cluster.CreateWithNodeImage("kindest/node:v1.35.0"), ); err != nil { return 0, fmt.Errorf("failed to create kind cluster: %w", err) } @@ -146,34 +155,39 @@ func runTests(m *testing.M) (int, error) { if err != nil { return 0, fmt.Errorf("error loading kubeconfig: %w", err) } + kubeClient, err = client.NewWithWatch(restCfg, client.Options{Scheme: tsapi.GlobalScheme}) if err != nil { return 0, fmt.Errorf("error creating Kubernetes client: %w", err) } var ( - clusterLoginServer string // Login server from cluster Pod point of view. - clientID, clientSecret string // OAuth client for the operator to use. + clientID, clientSecret string // OAuth client for the first tailnet (for the operator to use). caPaths []string // Extra CA cert file paths to add to images. - certsDir string = filepath.Join(tmp, "certs") // Directory containing extra CA certs to add to images. + certsDir = filepath.Join(tmp, "certs") // Directory containing extra CA certs to add to images. + secondClientID, secondClientSecret string // OAuth client for the second tailnet (for the operator to use). ) if *fDevcontrol { // Deploy pebble and get its certs. - if err := applyPebbleResources(ctx, kubeClient); err != nil { + if err = applyPebbleResources(ctx, kubeClient); err != nil { return 0, fmt.Errorf("failed to apply pebble resources: %w", err) } + pebblePod, err := waitForPodReady(ctx, logger, kubeClient, ns, client.MatchingLabels{"app": "pebble"}) if err != nil { return 0, fmt.Errorf("pebble pod not ready: %w", err) } - if err := forwardLocalPortToPod(ctx, logger, restCfg, ns, pebblePod, 15000); err != nil { + + if err = forwardLocalPortToPod(ctx, logger, restCfg, ns, pebblePod, 15000); err != nil { return 0, fmt.Errorf("failed to set up port forwarding to pebble: %w", err) } + testCAs = x509.NewCertPool() if ok := testCAs.AppendCertsFromPEM(pebbleMiniCACert); !ok { return 0, fmt.Errorf("failed to parse pebble minica cert") } + var pebbleCAChain []byte for _, path := range []string{"/intermediates/0", "/roots/0"} { pem, err := pebbleGet(ctx, 15000, path) @@ -182,20 +196,25 @@ func runTests(m *testing.M) (int, error) { } pebbleCAChain = append(pebbleCAChain, pem...) } + if ok := testCAs.AppendCertsFromPEM(pebbleCAChain); !ok { return 0, fmt.Errorf("failed to parse pebble ca chain cert") } - if err := os.MkdirAll(certsDir, 0755); err != nil { + + if err = os.MkdirAll(certsDir, 0755); err != nil { return 0, fmt.Errorf("failed to create certs dir: %w", err) } + pebbleCAChainPath := filepath.Join(certsDir, "pebble-ca-chain.crt") - if err := os.WriteFile(pebbleCAChainPath, pebbleCAChain, 0644); err != nil { + if err = os.WriteFile(pebbleCAChainPath, pebbleCAChain, 0644); err != nil { return 0, fmt.Errorf("failed to write pebble CA chain: %w", err) } + pebbleMiniCACertPath := filepath.Join(certsDir, "pebble.minica.crt") - if err := os.WriteFile(pebbleMiniCACertPath, pebbleMiniCACert, 0644); err != nil { + if err = os.WriteFile(pebbleMiniCACertPath, pebbleMiniCACert, 0644); err != nil { return 0, fmt.Errorf("failed to write pebble minica: %w", err) } + caPaths = []string{pebbleCAChainPath, pebbleMiniCACertPath} if !*fSkipCleanup { defer os.RemoveAll(certsDir) @@ -209,13 +228,15 @@ func runTests(m *testing.M) (int, error) { // For Pods -> devcontrol (tailscale clients joining the tailnet): // * Create ssh-server Deployment in cluster. // * Create reverse ssh tunnel that goes from ssh-server port 31544 to localhost:31544. - if err := forwardLocalPortToPod(ctx, logger, restCfg, ns, pebblePod, 8055); err != nil { + if err = forwardLocalPortToPod(ctx, logger, restCfg, ns, pebblePod, 8055); err != nil { return 0, fmt.Errorf("failed to set up port forwarding to pebble: %w", err) } + privateKey, publicKey, err := readOrGenerateSSHKey(tmp) if err != nil { return 0, fmt.Errorf("failed to read or generate SSH key: %w", err) } + if !*fSkipCleanup { defer os.Remove(privateKeyPath) } @@ -224,6 +245,7 @@ func runTests(m *testing.M) (int, error) { if err != nil { return 0, fmt.Errorf("failed to set up cluster->devcontrol connection: %w", err) } + if !*fSkipCleanup { defer func() { if err := cleanupSSHResources(context.Background(), kubeClient); err != nil { @@ -244,7 +266,7 @@ func runTests(m *testing.M) (int, error) { var apiKeyData struct { APIKey string `json:"apiKey"` } - if err := json.Unmarshal(b, &apiKeyData); err != nil { + if err = json.Unmarshal(b, &apiKeyData); err != nil { return 0, fmt.Errorf("failed to parse api-key.json: %w", err) } if apiKeyData.APIKey == "" { @@ -252,75 +274,96 @@ func runTests(m *testing.M) (int, error) { } // Finish setting up tsClient. - tsClient = tailscale.NewClient("-", tailscale.APIKey(apiKeyData.APIKey)) - tsClient.BaseURL = "http://localhost:31544" + tsClient = &tailscale.Client{ + APIKey: apiKeyData.APIKey, + BaseURL: must.Get(url.Parse("http://localhost:31544")), + } // Set ACLs and create OAuth client. - req, _ := http.NewRequest("POST", tsClient.BuildTailnetURL("acl"), bytes.NewReader(requiredACLs)) - resp, err := tsClient.Do(req) - if err != nil { - return 0, fmt.Errorf("failed to set ACLs: %w", err) + if err = tsClient.PolicyFile().Set(ctx, string(requiredACLs), ""); err != nil { + return 0, fmt.Errorf("failed to set policy file: %w", err) } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - b, _ := io.ReadAll(resp.Body) - return 0, fmt.Errorf("HTTP %d setting ACLs: %s", resp.StatusCode, string(b)) - } - logger.Infof("ACLs configured") - reqBody, err := json.Marshal(map[string]any{ - "keyType": "client", - "scopes": []string{"auth_keys", "devices:core", "services"}, - "tags": []string{"tag:k8s-operator"}, - "description": "k8s-operator client for e2e tests", + logger.Info("ACLs configured for first tailnet") + + key, err := tsClient.Keys().CreateOAuthClient(ctx, tailscale.CreateOAuthClientRequest{ + Scopes: []string{"auth_keys", "devices:core", "services"}, + Tags: []string{"tag:k8s-operator"}, + Description: "k8s-operator client for e2e tests", }) if err != nil { - return 0, fmt.Errorf("failed to marshal OAuth client creation request: %w", err) + return 0, fmt.Errorf("failed to create OAuth client for first tailnet: %w", err) + } + clientID = key.ID + clientSecret = key.Key + + logger.Info("OAuth credentials set for first tailnet") + + // Create second tailnet. The bootstrap credentials returned have 'all' permissions- + // they are used for administrative actions and to create a separately scoped + // Oauth client for the k8s operator. + bootstrapClient, err := createTailnet(ctx, tsClient) + if err != nil { + return 0, fmt.Errorf("failed to create second tailnet: %w", err) } - req, _ = http.NewRequest("POST", tsClient.BuildTailnetURL("keys"), bytes.NewReader(reqBody)) - resp, err = tsClient.Do(req) + + // Set HTTPS on second tailnet. + err = bootstrapClient.TailnetSettings().Update(ctx, tailscale.UpdateTailnetSettingsRequest{HTTPSEnabled: new(true)}) if err != nil { - return 0, fmt.Errorf("failed to create OAuth client: %w", err) + return 0, fmt.Errorf("failed to configure https for second tailnet: %w", err) } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - b, _ := io.ReadAll(resp.Body) - return 0, fmt.Errorf("HTTP %d creating OAuth client: %s", resp.StatusCode, string(b)) + logger.Info("HTTPS settings configured for second tailnet") + + // Set ACLs for second tailnet. + if err = bootstrapClient.PolicyFile().Set(ctx, string(requiredACLs), ""); err != nil { + return 0, fmt.Errorf("failed to set policy file: %w", err) } - var key struct { - ID string `json:"id"` - Key string `json:"key"` + + logger.Info("ACLs configured for second tailnet") + + // Create an OAuth client for the second tailnet to be used + // by the k8s-operator. + secondKey, err := bootstrapClient.Keys().CreateOAuthClient(ctx, tailscale.CreateOAuthClientRequest{ + Scopes: []string{"auth_keys", "devices:core", "services"}, + Tags: []string{"tag:k8s-operator"}, + Description: "k8s-operator client for e2e tests", + }) + if err != nil { + return 0, fmt.Errorf("failed to create OAuth client for second tailnet: %w", err) } - if err := json.NewDecoder(resp.Body).Decode(&key); err != nil { - return 0, fmt.Errorf("failed to decode OAuth client creation response: %w", err) + secondClientID = secondKey.ID + secondClientSecret = secondKey.Key + + secondTSClient, err = tailscaleClientFromSecret(ctx, "http://localhost:31544", secondClientID, secondClientSecret) + if err != nil { + return 0, fmt.Errorf("failed to set up second tailnet client: %w", err) } - clientID = key.ID - clientSecret = key.Key + } else { clientSecret = os.Getenv("TS_API_CLIENT_SECRET") if clientSecret == "" { return 0, fmt.Errorf("must use --devcontrol or set TS_API_CLIENT_SECRET to an OAuth client suitable for the operator") } - // Format is "tskey-client--". - parts := strings.Split(clientSecret, "-") - if len(parts) != 4 { - return 0, fmt.Errorf("TS_API_CLIENT_SECRET is not valid") + clientID, err = clientIDFromSecret(clientSecret) + if err != nil { + return 0, fmt.Errorf("failed to get client id from secret: %w", err) } - clientID = parts[2] - credentials := clientcredentials.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - TokenURL: fmt.Sprintf("%s/api/v2/oauth/token", ipn.DefaultControlURL), - Scopes: []string{"auth_keys"}, + tsClient, err = tailscaleClientFromSecret(ctx, ipn.DefaultControlURL, clientID, clientSecret) + if err != nil { + return 0, fmt.Errorf("failed to set up first tailnet client: %w", err) + } + secondClientSecret = os.Getenv("SECOND_TS_API_CLIENT_SECRET") + if secondClientSecret == "" { + return 0, fmt.Errorf("must use --devcontrol or set SECOND_TS_API_CLIENT_SECRET to an OAuth client suitable for the operator") + } + secondClientID, err = clientIDFromSecret(secondClientSecret) + if err != nil { + return 0, fmt.Errorf("failed to get client id from secret: %w", err) } - tk, err := credentials.Token(ctx) + secondTSClient, err = tailscaleClientFromSecret(ctx, ipn.DefaultControlURL, secondClientID, secondClientSecret) if err != nil { - return 0, fmt.Errorf("failed to get OAuth token: %w", err) + return 0, fmt.Errorf("failed to set up second tailnet client: %w", err) } - // An access token will last for an hour which is plenty of time for - // the tests to run. No need for token refresh logic. - tsClient = tailscale.NewClient("-", tailscale.APIKey(tk.AccessToken)) - tsClient.BaseURL = "http://localhost:31544" } var ossTag string @@ -388,6 +431,15 @@ func runTests(m *testing.M) (int, error) { if err != nil { return 0, fmt.Errorf("failed to load helm chart: %w", err) } + extraEnv := []map[string]any{ + { + "name": "K8S_PROXY_IMAGE", + "value": "local/k8s-proxy:" + ossTag, + }, + } + if *fDevcontrol { + extraEnv = append(extraEnv, map[string]any{"name": "TS_DEBUG_ACME_DIRECTORY_URL", "value": "https://pebble:14000/dir"}) + } values := map[string]any{ "loginServer": clusterLoginServer, "oauth": map[string]any{ @@ -398,17 +450,8 @@ func runTests(m *testing.M) (int, error) { "mode": "true", }, "operatorConfig": map[string]any{ - "logging": "debug", - "extraEnv": []map[string]any{ - { - "name": "K8S_PROXY_IMAGE", - "value": "local/k8s-proxy:" + ossTag, - }, - { - "name": "TS_DEBUG_ACME_DIRECTORY_URL", - "value": "https://pebble:14000/dir", - }, - }, + "logging": "debug", + "extraEnv": extraEnv, "image": map[string]any{ "repo": "local/k8s-operator", "tag": ossTag, @@ -438,7 +481,7 @@ func runTests(m *testing.M) (int, error) { return 0, fmt.Errorf("failed to install %q via helm: %w", relName, err) } - if err := applyDefaultProxyClass(ctx, kubeClient); err != nil { + if err := applyDefaultProxyClass(ctx, logger, kubeClient); err != nil { return 0, fmt.Errorf("failed to apply default ProxyClass: %w", err) } @@ -447,18 +490,24 @@ func runTests(m *testing.M) (int, error) { caps.Devices.Create.Ephemeral = true caps.Devices.Create.Tags = []string{"tag:k8s"} - authKey, authKeyMeta, err := tsClient.CreateKey(ctx, caps) + authKey, err := tsClient.Keys().CreateAuthKey(ctx, tailscale.CreateKeyRequest{Capabilities: caps}) if err != nil { - return 0, err + return 0, fmt.Errorf("failed to create auth key for first tailnet: %w", err) + } + defer tsClient.Keys().Delete(context.Background(), authKey.ID) + + secondAuthKey, err := secondTSClient.Keys().CreateAuthKey(ctx, tailscale.CreateKeyRequest{Capabilities: caps}) + if err != nil { + return 0, fmt.Errorf("failed to create auth key for second tailnet: %w", err) } - defer tsClient.DeleteKey(context.Background(), authKeyMeta.ID) + defer secondTSClient.Keys().Delete(context.Background(), secondAuthKey.ID) tnClient = &tsnet.Server{ - ControlURL: tsClient.BaseURL, + ControlURL: tsClient.BaseURL.String(), Hostname: "test-proxy", Ephemeral: true, Store: &mem.Store{}, - AuthKey: authKey, + AuthKey: authKey.Key, } _, err = tnClient.Up(ctx) if err != nil { @@ -466,9 +515,64 @@ func runTests(m *testing.M) (int, error) { } defer tnClient.Close() + secondTNClient = &tsnet.Server{ + ControlURL: secondTSClient.BaseURL.String(), + Hostname: "test-proxy", + Ephemeral: true, + Store: &mem.Store{}, + AuthKey: secondAuthKey.Key, + } + _, err = secondTNClient.Up(ctx) + if err != nil { + return 0, err + } + defer secondTNClient.Close() + + // Create the tailnet Secret in the tailscale namespace. + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "second-tailnet-credentials", + Namespace: "tailscale", + }, + Data: map[string][]byte{ + "client_id": []byte(secondClientID), + "client_secret": []byte(secondClientSecret), + }, + } + if err := createOrUpdate(ctx, kubeClient, secret); err != nil { + return 0, fmt.Errorf("failed to create second-tailnet-credentials Secret: %w", err) + } + defer kubeClient.Delete(context.Background(), secret) + + // Create the Tailnet resource. + tn := &tsapi.Tailnet{ + ObjectMeta: metav1.ObjectMeta{ + Name: "second-tailnet", + }, + Spec: tsapi.TailnetSpec{ + LoginURL: clusterLoginServer, + Credentials: tsapi.TailnetCredentials{ + SecretName: "second-tailnet-credentials", + }, + }, + } + if err := createOrUpdate(ctx, kubeClient, tn); err != nil { + return 0, fmt.Errorf("failed to create second-tailnet Tailnet: %w", err) + } + defer kubeClient.Delete(context.Background(), tn) + return m.Run(), nil } +func clientIDFromSecret(clientSecret string) (string, error) { + // Format is "tskey-client--". + parts := strings.Split(clientSecret, "-") + if len(parts) != 4 { + return "", fmt.Errorf("secret is not valid") + } + return parts[2], nil +} + func upgraderOrInstaller(cfg *action.Configuration, releaseName string) helmInstallerFunc { hist := action.NewHistory(cfg) hist.Max = 1 @@ -537,7 +641,16 @@ func tagForRepo(dir string) (string, error) { return tag, nil } -func applyDefaultProxyClass(ctx context.Context, cl client.Client) error { +func applyDefaultProxyClass(ctx context.Context, logger *zap.SugaredLogger, cl client.Client) error { + var env []tsapi.Env + if *fDevcontrol { + env = []tsapi.Env{ + { + Name: "TS_DEBUG_ACME_DIRECTORY_URL", + Value: "https://pebble:14000/dir", + }, + } + } pc := &tsapi.ProxyClass{ TypeMeta: metav1.TypeMeta{ APIVersion: tsapi.SchemeGroupVersion.String(), @@ -554,6 +667,7 @@ func applyDefaultProxyClass(ctx context.Context, cl client.Client) error { }, TailscaleContainer: &tsapi.Container{ ImagePullPolicy: "IfNotPresent", + Env: env, }, }, }, @@ -565,6 +679,24 @@ func applyDefaultProxyClass(ctx context.Context, cl client.Client) error { return fmt.Errorf("failed to apply default ProxyClass: %w", err) } + // Wait for the ProxyClass to be marked ready. + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + for { + if err := cl.Get(ctx, client.ObjectKeyFromObject(pc), pc); err != nil { + return fmt.Errorf("failed to get default ProxyClass: %w", err) + } + if tsoperator.ProxyClassIsReady(pc) { + break + } + logger.Info("waiting for default ProxyClass to be ready...") + select { + case <-ctx.Done(): + return fmt.Errorf("timeout waiting for default ProxyClass to be ready") + case <-time.After(time.Second): + } + } + return nil } @@ -699,3 +831,65 @@ func buildImage(ctx context.Context, dir, repo, target, tag string, extraCACerts return nil } + +func createOrUpdate(ctx context.Context, cl client.Client, obj client.Object) error { + if err := cl.Create(ctx, obj); err != nil { + if !apierrors.IsAlreadyExists(err) { + return err + } + return cl.Update(ctx, obj) + } + return nil +} + +// createTailnet creates a new tailnet and returns a tailscale.Client +// authenticated against it using the bootstrap credentials included in the +// creation response. +func createTailnet(ctx context.Context, tsClient *tailscale.Client) (*tailscale.Client, error) { + tailnetName := fmt.Sprintf("second-tailnet-%d", time.Now().Unix()) + body, err := json.Marshal(map[string]any{"displayName": tailnetName}) + if err != nil { + return nil, fmt.Errorf("failed to marshal tailnet creation request: %w", err) + } + // TODO(beckypauley): change to use a method on tailscale.Client once this is available. + req, _ := http.NewRequestWithContext(ctx, "POST", tsClient.BaseURL.String()+"/api/v2/organizations/-/tailnets", bytes.NewBuffer(body)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tsClient.APIKey)) + resp, err := tsClient.HTTP.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to create tailnet: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP %d creating tailnet: %s", resp.StatusCode, string(b)) + } + var result struct { + OauthClient struct { + ID string `json:"id"` + Secret string `json:"secret"` + } `json:"oauthClient"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + return tailscaleClientFromSecret(ctx, tsClient.BaseURL.String(), result.OauthClient.ID, result.OauthClient.Secret) +} + +// tailscaleClientFromSecret exchanges OAuth client credentials for an access token and +// returns a tailscale.Client configured to use it. The token is valid for +// one hour, which is sufficient for the tests to run. No need for refresh logic. +func tailscaleClientFromSecret(ctx context.Context, baseURL, clientID, clientSecret string) (*tailscale.Client, error) { + cfg := clientcredentials.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + TokenURL: fmt.Sprintf("%s/api/v2/oauth/token", baseURL), + } + tk, err := cfg.Token(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get OAuth token for client %q: %w", clientID, err) + } + return &tailscale.Client{ + APIKey: tk.AccessToken, + BaseURL: must.Get(url.Parse(baseURL)), + }, nil +} diff --git a/cmd/k8s-operator/e2e/ssh.go b/cmd/k8s-operator/e2e/ssh.go index 371c44f9d4544..9adcce6e3eee0 100644 --- a/cmd/k8s-operator/e2e/ssh.go +++ b/cmd/k8s-operator/e2e/ssh.go @@ -26,7 +26,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" tailscaleroot "tailscale.com" - "tailscale.com/types/ptr" ) const ( @@ -206,7 +205,7 @@ func applySSHResources(ctx context.Context, cl client.Client, alpineTag string, func cleanupSSHResources(ctx context.Context, cl client.Client) error { noGrace := &client.DeleteOptions{ - GracePeriodSeconds: ptr.To[int64](0), + GracePeriodSeconds: new(int64(0)), } if err := cl.Delete(ctx, sshDeployment("", nil), noGrace); err != nil { return fmt.Errorf("failed to delete ssh-server Deployment: %w", err) @@ -232,7 +231,7 @@ func sshDeployment(tag string, pubKey []byte) *appsv1.Deployment { Namespace: ns, }, Spec: appsv1.DeploymentSpec{ - Replicas: ptr.To[int32](1), + Replicas: new(int32(1)), Selector: &metav1.LabelSelector{ MatchLabels: map[string]string{ "app": "ssh-server", diff --git a/cmd/k8s-operator/egress-eps.go b/cmd/k8s-operator/egress-eps.go index 5181edf49a26c..9f8510165bea9 100644 --- a/cmd/k8s-operator/egress-eps.go +++ b/cmd/k8s-operator/egress-eps.go @@ -20,8 +20,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/kube/egressservices" - "tailscale.com/types/ptr" ) // egressEpsReconciler reconciles EndpointSlices for tailnet services exposed to cluster via egress ProxyGroup proxies. @@ -91,7 +91,7 @@ func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Requ lg.Debugf("No egress config found, likely because ProxyGroup has not been created") return res, nil } - cfg, ok := (*cfgs)[tailnetSvc] + cfg, ok := cfgs[tailnetSvc] if !ok { lg.Infof("[unexpected] configuration for tailnet service %s not found", tailnetSvc) return res, nil @@ -120,9 +120,9 @@ func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Requ Hostname: (*string)(&pod.UID), Addresses: []string{podIP}, Conditions: discoveryv1.EndpointConditions{ - Ready: ptr.To(true), - Serving: ptr.To(true), - Terminating: ptr.To(false), + Ready: new(true), + Serving: new(true), + Terminating: new(false), }, }) } diff --git a/cmd/k8s-operator/egress-eps_test.go b/cmd/k8s-operator/egress-eps_test.go index 47acb64f27458..6335b4eb8454b 100644 --- a/cmd/k8s-operator/egress-eps_test.go +++ b/cmd/k8s-operator/egress-eps_test.go @@ -11,7 +11,6 @@ import ( "math/rand/v2" "testing" - "github.com/AlekSi/pointer" "go.uber.org/zap" corev1 "k8s.io/api/core/v1" discoveryv1 "k8s.io/api/discovery/v1" @@ -106,11 +105,11 @@ func TestTailscaleEgressEndpointSlices(t *testing.T) { expectReconciled(t, er, "operator-ns", "foo") eps.Endpoints = append(eps.Endpoints, discoveryv1.Endpoint{ Addresses: []string{"10.0.0.1"}, - Hostname: pointer.To("foo"), + Hostname: new("foo"), Conditions: discoveryv1.EndpointConditions{ - Serving: pointer.ToBool(true), - Ready: pointer.ToBool(true), - Terminating: pointer.ToBool(false), + Serving: new(true), + Ready: new(true), + Terminating: new(false), }, }) expectEqual(t, fc, eps) diff --git a/cmd/k8s-operator/egress-pod-readiness_test.go b/cmd/k8s-operator/egress-pod-readiness_test.go index baa1442671907..0cf9108f5cd20 100644 --- a/cmd/k8s-operator/egress-pod-readiness_test.go +++ b/cmd/k8s-operator/egress-pod-readiness_test.go @@ -24,7 +24,6 @@ import ( tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" "tailscale.com/tstest" - "tailscale.com/types/ptr" ) func TestEgressPodReadiness(t *testing.T) { @@ -48,7 +47,7 @@ func TestEgressPodReadiness(t *testing.T) { }, Spec: tsapi.ProxyGroupSpec{ Type: "egress", - Replicas: ptr.To(int32(3)), + Replicas: new(int32(3)), }, } mustCreate(t, fc, pg) diff --git a/cmd/k8s-operator/egress-services-readiness_test.go b/cmd/k8s-operator/egress-services-readiness_test.go index ba89903df2f29..96d76cc4e7252 100644 --- a/cmd/k8s-operator/egress-services-readiness_test.go +++ b/cmd/k8s-operator/egress-services-readiness_test.go @@ -9,7 +9,6 @@ import ( "fmt" "testing" - "github.com/AlekSi/pointer" "go.uber.org/zap" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" @@ -145,9 +144,9 @@ func setEndpointForReplica(pg *tsapi.ProxyGroup, ordinal int32, eps *discoveryv1 eps.Endpoints = append(eps.Endpoints, discoveryv1.Endpoint{ Addresses: []string{p.Status.PodIPs[0].IP}, Conditions: discoveryv1.EndpointConditions{ - Ready: pointer.ToBool(true), - Serving: pointer.ToBool(true), - Terminating: pointer.ToBool(false), + Ready: new(true), + Serving: new(true), + Terminating: new(false), }, }) } diff --git a/cmd/k8s-operator/egress-services.go b/cmd/k8s-operator/egress-services.go index 90ab2c88270ee..b9a3f8eaba799 100644 --- a/cmd/k8s-operator/egress-services.go +++ b/cmd/k8s-operator/egress-services.go @@ -30,6 +30,7 @@ import ( "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" + tsoperator "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/egressservices" @@ -347,11 +348,11 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s return nil, false, nil } tailnetSvc := tailnetSvcName(svc) - gotCfg := (*cfgs)[tailnetSvc] + gotCfg := cfgs[tailnetSvc] wantsCfg := egressSvcCfg(svc, clusterIPSvc, esr.tsNamespace, lg) if !reflect.DeepEqual(gotCfg, wantsCfg) { lg.Debugf("updating egress services ConfigMap %s", cm.Name) - mak.Set(cfgs, tailnetSvc, wantsCfg) + mak.Set(&cfgs, tailnetSvc, wantsCfg) bs, err := json.Marshal(cfgs) if err != nil { return nil, false, fmt.Errorf("error marshalling egress services configs: %w", err) @@ -457,7 +458,8 @@ func (esr *egressSvcsReconciler) clusterIPSvcForEgress(crl map[string]string) *c Labels: crl, }, Spec: corev1.ServiceSpec{ - Type: corev1.ServiceTypeClusterIP, + Type: corev1.ServiceTypeClusterIP, + IPFamilyPolicy: new(corev1.IPFamilyPolicyPreferDualStack), }, } } @@ -484,19 +486,19 @@ func (esr *egressSvcsReconciler) ensureEgressSvcCfgDeleted(ctx context.Context, lggr.Debugf("ConfigMap does not contain egress service configs") return nil } - cfgs := &egressservices.Configs{} - if err := json.Unmarshal(bs, cfgs); err != nil { + cfgs := egressservices.Configs{} + if err := json.Unmarshal(bs, &cfgs); err != nil { return fmt.Errorf("error unmarshalling egress services configs") } tailnetSvc := tailnetSvcName(svc) - _, ok := (*cfgs)[tailnetSvc] + _, ok := cfgs[tailnetSvc] if !ok { lggr.Debugf("ConfigMap does not contain egress service config, likely because it was already deleted") return nil } - lggr.Infof("before deleting config %+#v", *cfgs) - delete(*cfgs, tailnetSvc) - lggr.Infof("after deleting config %+#v", *cfgs) + lggr.Infof("before deleting config %+#v", cfgs) + delete(cfgs, tailnetSvc) + lggr.Infof("after deleting config %+#v", cfgs) bs, err := json.Marshal(cfgs) if err != nil { return fmt.Errorf("error marshalling egress services configs: %w", err) @@ -648,7 +650,7 @@ func isEgressSvcForProxyGroup(obj client.Object) bool { // egressSvcConfig returns a ConfigMap that contains egress services configuration for the provided ProxyGroup as well // as unmarshalled configuration from the ConfigMap. -func egressSvcsConfigs(ctx context.Context, cl client.Client, proxyGroupName, tsNamespace string) (cm *corev1.ConfigMap, cfgs *egressservices.Configs, err error) { +func egressSvcsConfigs(ctx context.Context, cl client.Client, proxyGroupName, tsNamespace string) (cm *corev1.ConfigMap, cfgs egressservices.Configs, err error) { name := pgEgressCMName(proxyGroupName) cm = &corev1.ConfigMap{ ObjectMeta: metav1.ObjectMeta{ @@ -663,9 +665,9 @@ func egressSvcsConfigs(ctx context.Context, cl client.Client, proxyGroupName, ts if err != nil { return nil, nil, fmt.Errorf("error retrieving egress services ConfigMap %s: %v", name, err) } - cfgs = &egressservices.Configs{} + cfgs = egressservices.Configs{} if len(cm.BinaryData[egressservices.KeyEgressServices]) != 0 { - if err := json.Unmarshal(cm.BinaryData[egressservices.KeyEgressServices], cfgs); err != nil { + if err := json.Unmarshal(cm.BinaryData[egressservices.KeyEgressServices], &cfgs); err != nil { return nil, nil, fmt.Errorf("error unmarshaling egress services config %v: %w", cm.BinaryData[egressservices.KeyEgressServices], err) } } diff --git a/cmd/k8s-operator/egress-services_test.go b/cmd/k8s-operator/egress-services_test.go index 45861449191cb..a7dd79f7f1f84 100644 --- a/cmd/k8s-operator/egress-services_test.go +++ b/cmd/k8s-operator/egress-services_test.go @@ -21,6 +21,7 @@ import ( "k8s.io/apimachinery/pkg/util/intstr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/egressservices" "tailscale.com/tstest" @@ -203,8 +204,9 @@ func clusterIPSvc(name string, extNSvc *corev1.Service) *corev1.Service { Labels: labels, }, Spec: corev1.ServiceSpec{ - Type: corev1.ServiceTypeClusterIP, - Ports: ports, + Type: corev1.ServiceTypeClusterIP, + IPFamilyPolicy: new(corev1.IPFamilyPolicyPreferDualStack), + Ports: ports, }, } } @@ -243,7 +245,7 @@ func portsForEndpointSlice(svc *corev1.Service) []discoveryv1.EndpointPort { ports = append(ports, discoveryv1.EndpointPort{ Name: &p.Name, Protocol: &p.Protocol, - Port: pointer.ToInt32(p.TargetPort.IntVal), + Port: new(p.TargetPort.IntVal), }) } return ports @@ -283,11 +285,11 @@ func configFromCM(t *testing.T, cm *corev1.ConfigMap, svcName string) *egressser if !ok { return nil } - cfgs := &egressservices.Configs{} - if err := json.Unmarshal(cfgBs, cfgs); err != nil { + cfgs := egressservices.Configs{} + if err := json.Unmarshal(cfgBs, &cfgs); err != nil { t.Fatalf("error unmarshalling config: %v", err) } - cfg, ok := (*cfgs)[svcName] + cfg, ok := cfgs[svcName] if ok { return &cfg } diff --git a/cmd/k8s-operator/ingress-for-pg_test.go b/cmd/k8s-operator/ingress-for-pg_test.go index 3c9c839177bb5..8312dc5f70520 100644 --- a/cmd/k8s-operator/ingress-for-pg_test.go +++ b/cmd/k8s-operator/ingress-for-pg_test.go @@ -25,14 +25,14 @@ import ( "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "tailscale.com/client/tailscale/v2" - "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" tsoperator "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/k8s-operator/tsclient" "tailscale.com/kube/kubetypes" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) func TestIngressPGReconciler(t *testing.T) { @@ -49,7 +49,7 @@ func TestIngressPGReconciler(t *testing.T) { }, }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), DefaultBackend: &networkingv1.IngressBackend{ Service: &networkingv1.IngressServiceBackend{ Name: "test", @@ -89,7 +89,7 @@ func TestIngressPGReconciler(t *testing.T) { expectReconciled(t, ingPGR, "default", "test-ingress") // Verify Tailscale Service uses custom tags - tsSvc, err := ft.GetVIPService(t.Context(), "svc:my-svc") + tsSvc, err := ft.VIPServices().Get(t.Context(), "svc:my-svc") if err != nil { t.Fatalf("getting Tailscale Service: %v", err) } @@ -116,7 +116,7 @@ func TestIngressPGReconciler(t *testing.T) { }, }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), DefaultBackend: &networkingv1.IngressBackend{ Service: &networkingv1.IngressServiceBackend{ Name: "test", @@ -241,7 +241,7 @@ func TestIngressPGReconciler(t *testing.T) { }, }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), DefaultBackend: &networkingv1.IngressBackend{ Service: &networkingv1.IngressServiceBackend{ Name: "test", @@ -260,7 +260,7 @@ func TestIngressPGReconciler(t *testing.T) { expectReconciled(t, ingPGR, ing3.Namespace, ing3.Name) // Delete the service from "control" - ft.vipServices = make(map[tailcfg.ServiceName]*tailscale.VIPService) + ft.vipServices = make(map[string]tailscale.VIPService) // Delete the ingress and confirm we don't get stuck due to the VIP service not existing. if err = fc.Delete(t.Context(), ing3); err != nil { @@ -285,7 +285,7 @@ func TestIngressPGReconciler_UpdateIngressHostname(t *testing.T) { }, }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), DefaultBackend: &networkingv1.IngressBackend{ Service: &networkingv1.IngressServiceBackend{ Name: "test", @@ -320,11 +320,11 @@ func TestIngressPGReconciler_UpdateIngressHostname(t *testing.T) { verifyTailscaleService(t, ft, "svc:updated-svc", []string{"tcp:443"}) verifyTailscaledConfig(t, fc, "test-pg", []string{"svc:updated-svc"}) - _, err := ft.GetVIPService(context.Background(), "svc:my-svc") + _, err := ft.VIPServices().Get(context.Background(), "svc:my-svc") if err == nil { t.Fatalf("svc:my-svc not cleaned up") } - if !isErrorTailscaleServiceNotFound(err) { + if !tailscale.IsNotFound(err) { t.Fatalf("unexpected error: %v", err) } } @@ -340,7 +340,7 @@ func TestValidateIngress(t *testing.T) { }, }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), TLS: []networkingv1.IngressTLS{ {Hosts: []string{"test"}}, }, @@ -474,7 +474,7 @@ func TestValidateIngress(t *testing.T) { }, }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), TLS: []networkingv1.IngressTLS{ {Hosts: []string{"test"}}, }, @@ -521,7 +521,7 @@ func TestIngressPGReconciler_HTTPEndpoint(t *testing.T) { }, }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), DefaultBackend: &networkingv1.IngressBackend{ Service: &networkingv1.IngressServiceBackend{ Name: "test", @@ -652,7 +652,7 @@ func TestIngressPGReconciler_HTTPRedirect(t *testing.T) { }, }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), DefaultBackend: &networkingv1.IngressBackend{ Service: &networkingv1.IngressServiceBackend{ Name: "test", @@ -778,7 +778,7 @@ func TestIngressPGReconciler_HTTPEndpointAndRedirectConflict(t *testing.T) { }, }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), DefaultBackend: &networkingv1.IngressBackend{ Service: &networkingv1.IngressServiceBackend{ Name: "test", @@ -869,7 +869,7 @@ func TestIngressPGReconciler_MultiCluster(t *testing.T) { }, }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), TLS: []networkingv1.IngressTLS{ {Hosts: []string{"my-svc"}}, }, @@ -878,20 +878,18 @@ func TestIngressPGReconciler_MultiCluster(t *testing.T) { mustCreate(t, fc, ing) // Simulate existing Tailscale Service from another cluster - existingVIPSvc := &tailscale.VIPService{ + existingVIPSvc := tailscale.VIPService{ Name: "svc:my-svc", Annotations: map[string]string{ ownerAnnotation: `{"ownerrefs":[{"operatorID":"operator-2"}]}`, }, } - ft.vipServices = map[tailcfg.ServiceName]*tailscale.VIPService{ - "svc:my-svc": existingVIPSvc, - } + ft.VIPServices().CreateOrUpdate(t.Context(), existingVIPSvc) // Verify reconciliation adds our operator reference expectReconciled(t, ingPGR, "default", "test-ingress") - tsSvc, err := ft.GetVIPService(context.Background(), "svc:my-svc") + tsSvc, err := ft.VIPServices().Get(context.Background(), "svc:my-svc") if err != nil { t.Fatalf("getting Tailscale Service: %v", err) } @@ -918,7 +916,7 @@ func TestIngressPGReconciler_MultiCluster(t *testing.T) { } expectRequeue(t, ingPGR, "default", "test-ingress") - tsSvc, err = ft.GetVIPService(context.Background(), "svc:my-svc") + tsSvc, err = ft.VIPServices().Get(context.Background(), "svc:my-svc") if err != nil { t.Fatalf("getting Tailscale Service after deletion: %v", err) } @@ -1025,7 +1023,7 @@ func populateTLSSecret(t *testing.T, c client.Client, pgName, domain string) { func verifyTailscaleService(t *testing.T, ft *fakeTSClient, serviceName string, wantPorts []string) { t.Helper() - tsSvc, err := ft.GetVIPService(context.Background(), tailcfg.ServiceName(serviceName)) + tsSvc, err := ft.VIPServices().Get(context.Background(), serviceName) if err != nil { t.Fatalf("getting Tailscale Service %q: %v", serviceName, err) } @@ -1108,7 +1106,7 @@ func verifyTailscaledConfig(t *testing.T, fc client.Client, pgName string, expec Labels: pgSecretLabels(pgName, kubetypes.LabelSecretTypeConfig), }, Data: map[string][]byte{ - tsoperator.TailscaledConfigFileName(pgMinCapabilityVersion): []byte(fmt.Sprintf(`{"Version":""%s}`, expected)), + tsoperator.TailscaledConfigFileName(pgMinCapabilityVersion): fmt.Appendf(nil, `{"Version":""%s}`, expected), }, }) } @@ -1204,7 +1202,9 @@ func setupIngressTest(t *testing.T) (*HAIngressReconciler, client.Client, *fakeT fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} - ft := &fakeTSClient{} + ft := &fakeTSClient{ + vipServices: make(map[string]tailscale.VIPService), + } zl, err := zap.NewDevelopment() if err != nil { t.Fatal(err) @@ -1212,7 +1212,7 @@ func setupIngressTest(t *testing.T) (*HAIngressReconciler, client.Client, *fakeT ingPGR := &HAIngressReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, tsNamespace: "operator-ns", tsnetServer: fakeTsnetServer, diff --git a/cmd/k8s-operator/ingress_test.go b/cmd/k8s-operator/ingress_test.go index aac40897cc88e..c2a1198cce539 100644 --- a/cmd/k8s-operator/ingress_test.go +++ b/cmd/k8s-operator/ingress_test.go @@ -21,17 +21,21 @@ import ( "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "tailscale.com/client/tailscale/v2" + "tailscale.com/ipn" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/k8s-operator/tsclient" "tailscale.com/kube/kubetypes" "tailscale.com/tstest" - "tailscale.com/types/ptr" "tailscale.com/util/mak" ) func TestTailscaleIngress(t *testing.T) { fc := fake.NewFakeClient(ingressClass()) - ft := &fakeTSClient{} + ft := &fakeTSClient{ + vipServices: make(map[string]tailscale.VIPService), + } fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} zl, err := zap.NewDevelopment() if err != nil { @@ -42,7 +46,7 @@ func TestTailscaleIngress(t *testing.T) { ingressClassName: "tailscale", ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), tsnetServer: fakeTsnetServer, defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", @@ -59,7 +63,7 @@ func TestTailscaleIngress(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "ingress") opts := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -109,7 +113,7 @@ func TestTailscaleIngress(t *testing.T) { // 4. Resources get cleaned up when Ingress class is unset mustUpdate(t, fc, "default", "test", func(ing *networkingv1.Ingress) { - ing.Spec.IngressClassName = ptr.To("nginx") + ing.Spec.IngressClassName = new("nginx") }) expectReconciled(t, ingR, "default", "test") expectReconciled(t, ingR, "default", "test") // deleting Ingress STS requires two reconciles @@ -131,7 +135,7 @@ func TestTailscaleIngressHostname(t *testing.T) { ingressClassName: "tailscale", ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), tsnetServer: fakeTsnetServer, defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", @@ -270,7 +274,7 @@ func TestTailscaleIngressWithProxyClass(t *testing.T) { ingressClassName: "tailscale", ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), tsnetServer: fakeTsnetServer, defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", @@ -379,7 +383,7 @@ func TestTailscaleIngressWithServiceMonitor(t *testing.T) { ingressClassName: "tailscale", ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), tsnetServer: fakeTsnetServer, defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", @@ -531,7 +535,7 @@ func TestIngressProxyClassAnnotation(t *testing.T) { ingressClassName: "tailscale", ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: &fakeTSClient{}, + clients: tsclient.NewProvider(&fakeTSClient{}), tsnetServer: &fakeTSNetServer{certDomains: []string{"test-host"}}, defaultTags: []string{"tag:test"}, operatorNamespace: "operator-ns", @@ -602,7 +606,7 @@ func TestIngressLetsEncryptStaging(t *testing.T) { ingressClassName: "tailscale", ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: &fakeTSClient{}, + clients: tsclient.NewProvider(&fakeTSClient{}), tsnetServer: &fakeTSNetServer{certDomains: []string{"test-host"}}, defaultTags: []string{"tag:test"}, operatorNamespace: "operator-ns", @@ -639,7 +643,7 @@ func TestEmptyPath(t *testing.T) { name: "empty_path_with_prefix_type", paths: []networkingv1.HTTPIngressPath{ { - PathType: ptrPathType(networkingv1.PathTypePrefix), + PathType: new(networkingv1.PathTypePrefix), Path: "", Backend: *backend(), }, @@ -652,7 +656,7 @@ func TestEmptyPath(t *testing.T) { name: "empty_path_with_implementation_specific_type", paths: []networkingv1.HTTPIngressPath{ { - PathType: ptrPathType(networkingv1.PathTypeImplementationSpecific), + PathType: new(networkingv1.PathTypeImplementationSpecific), Path: "", Backend: *backend(), }, @@ -665,7 +669,7 @@ func TestEmptyPath(t *testing.T) { name: "empty_path_with_exact_type", paths: []networkingv1.HTTPIngressPath{ { - PathType: ptrPathType(networkingv1.PathTypeExact), + PathType: new(networkingv1.PathTypeExact), Path: "", Backend: *backend(), }, @@ -679,12 +683,12 @@ func TestEmptyPath(t *testing.T) { name: "two_competing_but_not_identical_paths_including_one_empty", paths: []networkingv1.HTTPIngressPath{ { - PathType: ptrPathType(networkingv1.PathTypeImplementationSpecific), + PathType: new(networkingv1.PathTypeImplementationSpecific), Path: "", Backend: *backend(), }, { - PathType: ptrPathType(networkingv1.PathTypeImplementationSpecific), + PathType: new(networkingv1.PathTypeImplementationSpecific), Path: "/", Backend: *backend(), }, @@ -711,7 +715,7 @@ func TestEmptyPath(t *testing.T) { ingressClassName: "tailscale", ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), tsnetServer: fakeTsnetServer, defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", @@ -760,11 +764,6 @@ func TestEmptyPath(t *testing.T) { } } -// ptrPathType is a helper function to return a pointer to the pathtype string (required for TestEmptyPath) -func ptrPathType(p networkingv1.PathType) *networkingv1.PathType { - return &p -} - func ingressClass() *networkingv1.IngressClass { return &networkingv1.IngressClass{ ObjectMeta: metav1.ObjectMeta{Name: "tailscale"}, @@ -799,7 +798,7 @@ func ingress() *networkingv1.Ingress { UID: "1234-UID", }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), DefaultBackend: backend(), TLS: []networkingv1.IngressTLS{ {Hosts: []string{"default-test"}}, @@ -817,7 +816,7 @@ func ingressWithPaths(paths []networkingv1.HTTPIngressPath) *networkingv1.Ingres UID: types.UID("1234-UID"), }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), Rules: []networkingv1.IngressRule{ { Host: "foo.tailnetxyz.ts.net", @@ -859,7 +858,7 @@ func TestTailscaleIngressWithHTTPRedirect(t *testing.T) { ingressClassName: "tailscale", ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), tsnetServer: fakeTsnetServer, defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", @@ -878,7 +877,7 @@ func TestTailscaleIngressWithHTTPRedirect(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "ingress") opts := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", diff --git a/cmd/k8s-operator/metrics_resources.go b/cmd/k8s-operator/metrics_resources.go index afb055018bb13..4384f4cba40bd 100644 --- a/cmd/k8s-operator/metrics_resources.go +++ b/cmd/k8s-operator/metrics_resources.go @@ -8,6 +8,7 @@ package main import ( "context" "fmt" + "maps" "reflect" "go.uber.org/zap" @@ -18,6 +19,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" + kube "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" ) @@ -226,13 +228,13 @@ func metricsResourceLabels(opts *metricsOpts) map[string]string { kubetypes.LabelManaged: "true", labelMetricsTarget: opts.proxyStsName, labelPromProxyType: opts.proxyType, - labelPromProxyParentName: opts.proxyLabels[LabelParentName], + labelPromProxyParentName: kube.TruncateLabelValue(opts.proxyLabels[LabelParentName]), } // Include namespace label for proxies created for a namespaced type. if isNamespacedProxyType(opts.proxyType) { - lbls[labelPromProxyParentNamespace] = opts.proxyLabels[LabelParentNamespace] + lbls[labelPromProxyParentNamespace] = kube.TruncateLabelValue(opts.proxyLabels[LabelParentNamespace]) } - lbls[labelPromJob] = promJobName(opts) + lbls[labelPromJob] = kube.TruncateLabelValue(promJobName(opts)) return lbls } @@ -249,11 +251,11 @@ func promJobName(opts *metricsOpts) string { func metricsSvcSelector(proxyLabels map[string]string, proxyType string) map[string]string { sel := map[string]string{ labelPromProxyType: proxyType, - labelPromProxyParentName: proxyLabels[LabelParentName], + labelPromProxyParentName: kube.TruncateLabelValue(proxyLabels[LabelParentName]), } // Include namespace label for proxies created for a namespaced type. if isNamespacedProxyType(proxyType) { - sel[labelPromProxyParentNamespace] = proxyLabels[LabelParentNamespace] + sel[labelPromProxyParentNamespace] = kube.TruncateLabelValue(proxyLabels[LabelParentNamespace]) } return sel } @@ -286,11 +288,7 @@ func isNamespacedProxyType(typ string) bool { func mergeMapKeys(a, b map[string]string) map[string]string { m := make(map[string]string, len(a)+len(b)) - for key, val := range b { - m[key] = val - } - for key, val := range a { - m[key] = val - } + maps.Copy(m, b) + maps.Copy(m, a) return m } diff --git a/cmd/k8s-operator/nameserver.go b/cmd/k8s-operator/nameserver.go index 522b460031530..f5565e5d30cee 100644 --- a/cmd/k8s-operator/nameserver.go +++ b/cmd/k8s-operator/nameserver.go @@ -31,7 +31,6 @@ import ( tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" "tailscale.com/tstime" - "tailscale.com/types/ptr" "tailscale.com/util/clientmetric" "tailscale.com/util/set" ) @@ -191,6 +190,8 @@ func (a *NameserverReconciler) maybeProvision(ctx context.Context, tsDNSCfg *tsa } if tsDNSCfg.Spec.Nameserver.Pod != nil { dCfg.tolerations = tsDNSCfg.Spec.Nameserver.Pod.Tolerations + dCfg.affinity = tsDNSCfg.Spec.Nameserver.Pod.Affinity + dCfg.nodeSelector = tsDNSCfg.Spec.Nameserver.Pod.NodeSelector } for _, deployable := range []deployable{saDeployable, deployDeployable, svcDeployable, cmDeployable} { @@ -218,14 +219,16 @@ type deployable struct { } type deployConfig struct { - replicas int32 - imageRepo string - imageTag string - labels map[string]string - ownerRefs []metav1.OwnerReference - namespace string - clusterIP string - tolerations []corev1.Toleration + replicas int32 + imageRepo string + imageTag string + labels map[string]string + ownerRefs []metav1.OwnerReference + namespace string + clusterIP string + tolerations []corev1.Toleration + affinity *corev1.Affinity + nodeSelector map[string]string } var ( @@ -245,12 +248,14 @@ var ( if err := yaml.Unmarshal(deployYaml, &d); err != nil { return fmt.Errorf("error unmarshalling Deployment yaml: %w", err) } - d.Spec.Replicas = ptr.To(cfg.replicas) + d.Spec.Replicas = new(cfg.replicas) d.Spec.Template.Spec.Containers[0].Image = fmt.Sprintf("%s:%s", cfg.imageRepo, cfg.imageTag) d.ObjectMeta.Namespace = cfg.namespace d.ObjectMeta.Labels = cfg.labels d.ObjectMeta.OwnerReferences = cfg.ownerRefs d.Spec.Template.Spec.Tolerations = cfg.tolerations + d.Spec.Template.Spec.Affinity = cfg.affinity + d.Spec.Template.Spec.NodeSelector = cfg.nodeSelector updateF := func(oldD *appsv1.Deployment) { oldD.Spec = d.Spec } diff --git a/cmd/k8s-operator/nameserver_test.go b/cmd/k8s-operator/nameserver_test.go index 531190cf21dc2..3ec00d5ed8859 100644 --- a/cmd/k8s-operator/nameserver_test.go +++ b/cmd/k8s-operator/nameserver_test.go @@ -23,7 +23,6 @@ import ( operatorutils "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/tstest" - "tailscale.com/types/ptr" "tailscale.com/util/mak" ) @@ -35,7 +34,7 @@ func TestNameserverReconciler(t *testing.T) { }, Spec: tsapi.DNSConfigSpec{ Nameserver: &tsapi.Nameserver{ - Replicas: ptr.To[int32](3), + Replicas: new(int32(3)), Image: &tsapi.NameserverImage{ Repo: "test", Tag: "v0.0.1", @@ -44,6 +43,9 @@ func TestNameserverReconciler(t *testing.T) { ClusterIP: "5.4.3.2", }, Pod: &tsapi.NameserverPod{ + NodeSelector: map[string]string{ + "foo": "bar", + }, Tolerations: []corev1.Toleration{ { Key: "some-key", @@ -52,6 +54,23 @@ func TestNameserverReconciler(t *testing.T) { Effect: corev1.TaintEffectNoSchedule, }, }, + Affinity: &corev1.Affinity{ + NodeAffinity: &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "some-key", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"some-value"}, + }, + }, + }, + }, + }, + }, + }, }, }, }, @@ -81,13 +100,13 @@ func TestNameserverReconciler(t *testing.T) { nameserverLabels := nameserverResourceLabels(dnsConfig.Name, tsNamespace) wantsDeploy := &appsv1.Deployment{ObjectMeta: metav1.ObjectMeta{Name: "nameserver", Namespace: tsNamespace}, TypeMeta: metav1.TypeMeta{Kind: "Deployment", APIVersion: appsv1.SchemeGroupVersion.Identifier()}} - t.Run("deployment has expected fields", func(t *testing.T) { + t.Run("deployment-expected-fields", func(t *testing.T) { if err = yaml.Unmarshal(deployYaml, wantsDeploy); err != nil { t.Fatalf("unmarshalling yaml: %v", err) } wantsDeploy.OwnerReferences = []metav1.OwnerReference{*ownerReference} wantsDeploy.Spec.Template.Spec.Containers[0].Image = "test:v0.0.1" - wantsDeploy.Spec.Replicas = ptr.To[int32](3) + wantsDeploy.Spec.Replicas = new(int32(3)) wantsDeploy.Namespace = tsNamespace wantsDeploy.ObjectMeta.Labels = nameserverLabels wantsDeploy.Spec.Template.Spec.Tolerations = []corev1.Toleration{ @@ -98,12 +117,32 @@ func TestNameserverReconciler(t *testing.T) { Effect: corev1.TaintEffectNoSchedule, }, } + wantsDeploy.Spec.Template.Spec.Affinity = &corev1.Affinity{ + NodeAffinity: &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "some-key", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"some-value"}, + }, + }, + }, + }, + }, + }, + } + wantsDeploy.Spec.Template.Spec.NodeSelector = map[string]string{ + "foo": "bar", + } expectEqual(t, fc, wantsDeploy) }) wantsSvc := &corev1.Service{ObjectMeta: metav1.ObjectMeta{Name: "nameserver", Namespace: tsNamespace}, TypeMeta: metav1.TypeMeta{Kind: "Service", APIVersion: corev1.SchemeGroupVersion.Identifier()}} - t.Run("service has expected fields", func(t *testing.T) { + t.Run("service-expected-fields", func(t *testing.T) { if err = yaml.Unmarshal(svcYaml, wantsSvc); err != nil { t.Fatalf("unmarshalling yaml: %v", err) } @@ -114,7 +153,7 @@ func TestNameserverReconciler(t *testing.T) { expectEqual(t, fc, wantsSvc) }) - t.Run("dns config status is set", func(t *testing.T) { + t.Run("dns-config-status-set", func(t *testing.T) { // Verify that DNSConfig advertizes the nameserver's Service IP address, // has the ready status condition and tailscale finalizer. mustUpdate(t, fc, "tailscale", "nameserver", func(svc *corev1.Service) { @@ -137,7 +176,7 @@ func TestNameserverReconciler(t *testing.T) { expectEqual(t, fc, dnsConfig) }) - t.Run("nameserver image can be updated", func(t *testing.T) { + t.Run("nameserver-image-updated", func(t *testing.T) { // Verify that nameserver image gets updated to match DNSConfig spec. mustUpdate(t, fc, "", "test", func(dnsCfg *tsapi.DNSConfig) { dnsCfg.Spec.Nameserver.Image.Tag = "v0.0.2" @@ -147,7 +186,7 @@ func TestNameserverReconciler(t *testing.T) { expectEqual(t, fc, wantsDeploy) }) - t.Run("reconciler does not overwrite custom configuration", func(t *testing.T) { + t.Run("reconciler-preserves-custom-config", func(t *testing.T) { // Verify that when another actor sets ConfigMap data, it does not get // overwritten by nameserver reconciler. dnsRecords := &operatorutils.Records{Version: "v1alpha1", IP4: map[string][]string{"foo.ts.net": {"1.2.3.4"}}} @@ -176,7 +215,7 @@ func TestNameserverReconciler(t *testing.T) { expectEqual(t, fc, wantCm) }) - t.Run("uses default nameserver image", func(t *testing.T) { + t.Run("uses-default-nameserver-image", func(t *testing.T) { // Verify that if dnsconfig.spec.nameserver.image.{repo,tag} are unset, // the nameserver image defaults to tailscale/k8s-nameserver:unstable. mustUpdate(t, fc, "", "test", func(dnsCfg *tsapi.DNSConfig) { diff --git a/cmd/k8s-operator/operator.go b/cmd/k8s-operator/operator.go index 1060c6f3da9e7..9f9c719973895 100644 --- a/cmd/k8s-operator/operator.go +++ b/cmd/k8s-operator/operator.go @@ -46,9 +46,9 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager/signals" "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/client/tailscale/v2" "tailscale.com/client/local" - "tailscale.com/client/tailscale" "tailscale.com/envknob" "tailscale.com/hostinfo" "tailscale.com/ipn" @@ -57,6 +57,7 @@ import ( tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/k8s-operator/reconciler/proxygrouppolicy" "tailscale.com/k8s-operator/reconciler/tailnet" + "tailscale.com/k8s-operator/tsclient" "tailscale.com/kube/kubetypes" "tailscale.com/tsnet" "tailscale.com/tstime" @@ -77,11 +78,13 @@ import ( // Generate CRD API docs. //go:generate go run github.com/elastic/crd-ref-docs --renderer=markdown --source-path=../../k8s-operator/apis/ --config=../../k8s-operator/api-docs-config.yaml --output-path=../../k8s-operator/api.md -func main() { - // Required to use our client API. We're fine with the instability since the - // client lives in the same repo as this code. - tailscale.I_Acknowledge_This_API_Is_Unstable = true +const ( + indexServiceProxyClass = ".metadata.annotations.service-proxy-class" + indexServiceExposed = ".metadata.annotations.service-expose" + indexServiceType = ".metadata.annotations.service-type" +) +func main() { var ( tsNamespace = defaultEnv("OPERATOR_NAMESPACE", "") tslogging = defaultEnv("OPERATOR_LOGGING", "info") @@ -149,7 +152,7 @@ func main() { })) } - rOpts := reconcilerOpts{ + runReconcilers(reconcilerOpts{ log: zlog, tsServer: s, tsClient: tsc, @@ -164,15 +167,14 @@ func main() { defaultProxyClass: defaultProxyClass, loginServer: loginServer, ingressClassName: ingressClassName, - } - runReconcilers(rOpts) + }) } // initTSNet initializes the tsnet.Server and logs in to Tailscale. If CLIENT_ID // is set, it authenticates to the Tailscale API using the federated OIDC workload // identity flow. Otherwise, it uses the CLIENT_ID_FILE and CLIENT_SECRET_FILE // environment variables to authenticate with static credentials. -func initTSNet(zlog *zap.SugaredLogger, loginServer string) (*tsnet.Server, tsClient) { +func initTSNet(zlog *zap.SugaredLogger, loginServer string) (*tsnet.Server, *tailscale.Client) { var ( clientID = defaultEnv("CLIENT_ID", "") // Used for workload identity federation. clientIDPath = defaultEnv("CLIENT_ID_FILE", "") // Used for static client credentials. @@ -181,19 +183,23 @@ func initTSNet(zlog *zap.SugaredLogger, loginServer string) (*tsnet.Server, tsCl kubeSecret = defaultEnv("OPERATOR_SECRET", "") operatorTags = defaultEnv("OPERATOR_INITIAL_TAGS", "tag:k8s-operator") ) + startlog := zlog.Named("startup") if clientID == "" && (clientIDPath == "" || clientSecretPath == "") { startlog.Fatalf("CLIENT_ID_FILE and CLIENT_SECRET_FILE must be set") // TODO(tomhjp): error message can mention WIF once it's publicly available. } + tsc, err := newTSClient(zlog.Named("ts-api-client"), clientID, clientIDPath, clientSecretPath, loginServer) if err != nil { startlog.Fatalf("error creating Tailscale client: %v", err) } + s := &tsnet.Server{ Hostname: hostname, Logf: zlog.Named("tailscaled").Debugf, ControlURL: loginServer, } + if p := os.Getenv("TS_PORT"); p != "" { port, err := strconv.ParseUint(p, 10, 16) if err != nil { @@ -201,6 +207,7 @@ func initTSNet(zlog *zap.SugaredLogger, loginServer string) (*tsnet.Server, tsCl } s.Port = uint16(port) } + if kubeSecret != "" { st, err := kubestore.New(logger.Discard, kubeSecret) if err != nil { @@ -208,6 +215,7 @@ func initTSNet(zlog *zap.SugaredLogger, loginServer string) (*tsnet.Server, tsCl } s.Store = st } + if err := s.Start(); err != nil { startlog.Fatalf("starting tailscale server: %v", err) } @@ -233,27 +241,29 @@ waitOnline: if loginDone { break } - caps := tailscale.KeyCapabilities{ - Devices: tailscale.KeyDeviceCapabilities{ - Create: tailscale.KeyDeviceCreateCapabilities{ - Reusable: false, - Preauthorized: true, - Tags: strings.Split(operatorTags, ","), - }, - }, - } - authkey, _, err := tsc.CreateKey(ctx, caps) + + var caps tailscale.KeyCapabilities + caps.Devices.Create.Reusable = false + caps.Devices.Create.Preauthorized = true + caps.Devices.Create.Tags = strings.Split(operatorTags, ",") + + authKey, err := tsc.Keys().CreateAuthKey(ctx, tailscale.CreateKeyRequest{Capabilities: caps}) if err != nil { startlog.Fatalf("creating operator authkey: %v", err) } - if err := lc.Start(ctx, ipn.Options{ - AuthKey: authkey, - }); err != nil { + + opts := ipn.Options{ + AuthKey: authKey.Key, + } + + if err = lc.Start(ctx, opts); err != nil { startlog.Fatalf("starting tailscale: %v", err) } - if err := lc.StartLoginInteractive(ctx); err != nil { + + if err = lc.StartLoginInteractive(ctx); err != nil { startlog.Fatalf("starting login: %v", err) } + startlog.Debugf("requested login by authkey") loginDone = true case "NeedsMachineAuth": @@ -280,6 +290,12 @@ func serviceManagedResourceFilterPredicate() predicate.Predicate { }) } +type ( + ClientProvider interface { + For(tailnet string) (tsclient.Client, error) + } +) + // runReconcilers starts the controller-runtime manager and registers the // ServiceReconciler. It blocks forever. func runReconcilers(opts reconcilerOpts) { @@ -328,11 +344,14 @@ func runReconcilers(opts reconcilerOpts) { startlog.Fatalf("could not create manager: %v", err) } + clients := tsclient.NewProvider(tsclient.Wrap(opts.tsClient)) + tailnetOptions := tailnet.ReconcilerOptions{ Client: mgr.GetClient(), TailscaleNamespace: opts.tailscaleNamespace, Clock: tstime.DefaultClock{}, Logger: opts.log, + Registry: clients, } if err = tailnet.NewReconciler(tailnetOptions).Register(mgr); err != nil { @@ -351,13 +370,18 @@ func runReconcilers(opts reconcilerOpts) { svcChildFilter := handler.EnqueueRequestsFromMapFunc(managedResourceHandlerForType("svc")) // If a ProxyClass changes, enqueue all Services labeled with that // ProxyClass's name. - proxyClassFilterForSvc := handler.EnqueueRequestsFromMapFunc(proxyClassHandlerForSvc(mgr.GetClient(), startlog)) + proxyClassFilterForSvc := handler.EnqueueRequestsFromMapFunc(proxyClassHandlerForSvc( + mgr.GetClient(), + startlog, + opts.defaultProxyClass, + opts.proxyActAsDefaultLoadBalancer, + )) eventRecorder := mgr.GetEventRecorderFor("tailscale-operator") ssr := &tailscaleSTSReconciler{ Client: mgr.GetClient(), tsnetServer: opts.tsServer, - tsClient: opts.tsClient, + clients: clients, defaultTags: strings.Split(opts.proxyTags, ","), operatorNamespace: opts.tailscaleNamespace, proxyImage: opts.proxyImage, @@ -389,6 +413,18 @@ func runReconcilers(opts reconcilerOpts) { if err := mgr.GetFieldIndexer().IndexField(context.Background(), new(corev1.Service), indexServiceProxyClass, indexProxyClass); err != nil { startlog.Fatalf("failed setting up ProxyClass indexer for Services: %v", err) } + if opts.defaultProxyClass != "" { + // If a default ProxyClass is specified, we'll need to list all objects + // that could be affected. For L3 ingress, this is Services with the + // "tailscale.com/expose" annotation and LoadBalancer services (either + // with the loadBalancerClass "tailscale", or unset if we're the default). + if err := mgr.GetFieldIndexer().IndexField(context.Background(), new(corev1.Service), indexServiceExposed, indexExposed); err != nil { + startlog.Fatalf("failed setting up exposed indexer for Services: %v", err) + } + if err := mgr.GetFieldIndexer().IndexField(context.Background(), new(corev1.Service), indexServiceType, indexType); err != nil { + startlog.Fatalf("failed setting up type indexer for Services: %v", err) + } + } ingressChildFilter := handler.EnqueueRequestsFromMapFunc(managedResourceHandlerForType("ingress")) // If a ProxyClassChanges, enqueue all Ingresses labeled with that @@ -437,7 +473,7 @@ func runReconcilers(opts reconcilerOpts) { Watches(&tsapi.ProxyGroup{}, ingressProxyGroupFilter). Complete(&HAIngressReconciler{ recorder: eventRecorder, - tsClient: opts.tsClient, + clients: clients, tsnetServer: opts.tsServer, defaultTags: strings.Split(opts.proxyTags, ","), Client: mgr.GetClient(), @@ -463,7 +499,7 @@ func runReconcilers(opts reconcilerOpts) { Watches(&discoveryv1.EndpointSlice{}, ingressSvcFromEpsFilter). Complete(&HAServiceReconciler{ recorder: eventRecorder, - tsClient: opts.tsClient, + clients: clients, defaultTags: strings.Split(opts.proxyTags, ","), Client: mgr.GetClient(), logger: opts.log.Named("service-pg-reconciler"), @@ -656,13 +692,14 @@ func runReconcilers(opts reconcilerOpts) { Watches(&rbacv1.Role{}, recorderFilter). Watches(&rbacv1.RoleBinding{}, recorderFilter). Complete(&RecorderReconciler{ - recorder: eventRecorder, - tsNamespace: opts.tailscaleNamespace, - Client: mgr.GetClient(), - log: opts.log.Named("recorder-reconciler"), - clock: tstime.DefaultClock{}, - tsClient: opts.tsClient, - loginServer: opts.loginServer, + recorder: eventRecorder, + tsNamespace: opts.tailscaleNamespace, + Client: mgr.GetClient(), + log: opts.log.Named("recorder-reconciler"), + clock: tstime.DefaultClock{}, + clients: clients, + authKeyRateLimits: make(map[string]*rate.Limiter), + authKeyReissuing: make(map[string]bool), }) if err != nil { startlog.Fatalf("could not create Recorder reconciler: %v", err) @@ -683,7 +720,7 @@ func runReconcilers(opts reconcilerOpts) { Client: mgr.GetClient(), recorder: eventRecorder, logger: opts.log.Named("kube-apiserver-ts-service-reconciler"), - tsClient: opts.tsClient, + clients: clients, tsNamespace: opts.tailscaleNamespace, defaultTags: strings.Split(opts.proxyTags, ","), operatorID: id, @@ -715,7 +752,7 @@ func runReconcilers(opts reconcilerOpts) { Client: mgr.GetClient(), log: opts.log.Named("proxygroup-reconciler"), clock: tstime.DefaultClock{}, - tsClient: opts.tsClient, + clients: clients, tsNamespace: opts.tailscaleNamespace, tsProxyImage: opts.proxyImage, @@ -740,7 +777,7 @@ func runReconcilers(opts reconcilerOpts) { type reconcilerOpts struct { log *zap.SugaredLogger tsServer *tsnet.Server - tsClient tsClient + tsClient *tailscale.Client tailscaleNamespace string // namespace in which operator resources will be deployed restConfig *rest.Config // config for connecting to the kube API server proxyImage string // : @@ -910,10 +947,27 @@ func indexProxyClass(o client.Object) []string { return []string{o.GetAnnotations()[LabelAnnotationProxyClass]} } +func indexExposed(o client.Object) []string { + if o.GetAnnotations()[AnnotationExpose] != "true" { + return nil + } + + return []string{o.GetAnnotations()[AnnotationExpose]} +} + +func indexType(o client.Object) []string { + svc, ok := o.(*corev1.Service) + if !ok { + return nil + } + + return []string{string(svc.Spec.Type)} +} + // proxyClassHandlerForSvc returns a handler that, for a given ProxyClass, // returns a list of reconcile requests for all Services labeled with // tailscale.com/proxy-class: . -func proxyClassHandlerForSvc(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { +func proxyClassHandlerForSvc(cl client.Client, logger *zap.SugaredLogger, defaultProxyClass string, isDefaultLoadBalancer bool) handler.MapFunc { return func(ctx context.Context, o client.Object) []reconcile.Request { svcList := new(corev1.ServiceList) labels := map[string]string{ @@ -932,13 +986,12 @@ func proxyClassHandlerForSvc(cl client.Client, logger *zap.SugaredLogger) handle seenSvcs.Add(fmt.Sprintf("%s/%s", svc.Namespace, svc.Name)) } - svcAnnotationList := new(corev1.ServiceList) - if err := cl.List(ctx, svcAnnotationList, client.MatchingFields{indexServiceProxyClass: o.GetName()}); err != nil { + if err := cl.List(ctx, svcList, client.MatchingFields{indexServiceProxyClass: o.GetName()}); err != nil { logger.Debugf("error listing Services for ProxyClass: %v", err) return nil } - for _, svc := range svcAnnotationList.Items { + for _, svc := range svcList.Items { nsname := fmt.Sprintf("%s/%s", svc.Namespace, svc.Name) if seenSvcs.Contains(nsname) { continue @@ -948,6 +1001,36 @@ func proxyClassHandlerForSvc(cl client.Client, logger *zap.SugaredLogger) handle seenSvcs.Add(nsname) } + if o.GetName() == defaultProxyClass { + // For the default ProxyClass, we also need to reconcile all exposed + // Services that don't have an explicit ProxyClass set. + for _, matcher := range []client.ListOption{ + client.MatchingFields{indexServiceExposed: "true"}, + client.MatchingFields{indexServiceType: string(corev1.ServiceTypeLoadBalancer)}, + } { + if err := cl.List(ctx, svcList, matcher); err != nil { + logger.Debugf("error listing exposed Services for ProxyClass: %v", err) + return nil + } + + for _, svc := range svcList.Items { + if hasProxyClassAnnotation(&svc) { + continue + } + if !shouldExpose(&svc, isDefaultLoadBalancer) { + continue + } + nsname := fmt.Sprintf("%s/%s", svc.Namespace, svc.Name) + if seenSvcs.Contains(nsname) { + continue + } + + reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&svc)}) + seenSvcs.Add(nsname) + } + } + } + return reqs } } diff --git a/cmd/k8s-operator/operator_test.go b/cmd/k8s-operator/operator_test.go index 53d16fbd225f3..b775a36fb5d11 100644 --- a/cmd/k8s-operator/operator_test.go +++ b/cmd/k8s-operator/operator_test.go @@ -24,13 +24,14 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/k8s-operator/apis/v1alpha1" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/k8s-operator/tsclient" "tailscale.com/kube/kubetypes" "tailscale.com/net/dns/resolvconffile" "tailscale.com/tstest" "tailscale.com/tstime" - "tailscale.com/types/ptr" "tailscale.com/util/dnsname" "tailscale.com/util/mak" ) @@ -44,7 +45,7 @@ func TestLoadBalancerClass(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -63,7 +64,7 @@ func TestLoadBalancerClass(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ AnnotationTailnetTargetFQDN: "invalid.example.com", }, @@ -71,7 +72,7 @@ func TestLoadBalancerClass(t *testing.T) { Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), }, }) @@ -94,7 +95,7 @@ func TestLoadBalancerClass(t *testing.T) { Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), }, Status: corev1.ServiceStatus{ Conditions: []metav1.Condition{{ @@ -119,7 +120,7 @@ func TestLoadBalancerClass(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") opts := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -204,7 +205,7 @@ func TestLoadBalancerClass(t *testing.T) { ObjectMeta: metav1.ObjectMeta{ Name: "test", Namespace: "default", - UID: types.UID("1234-UID"), + UID: "1234-UID", }, Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", @@ -224,7 +225,7 @@ func TestTailnetTargetFQDNAnnotation(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -242,7 +243,7 @@ func TestTailnetTargetFQDNAnnotation(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ AnnotationTailnetTargetFQDN: tailnetTargetFQDN, }, @@ -259,7 +260,7 @@ func TestTailnetTargetFQDNAnnotation(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -334,7 +335,7 @@ func TestTailnetTargetIPAnnotation(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -352,7 +353,7 @@ func TestTailnetTargetIPAnnotation(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ AnnotationTailnetTargetIP: tailnetTargetIP, }, @@ -369,7 +370,7 @@ func TestTailnetTargetIPAnnotation(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -443,7 +444,7 @@ func TestTailnetTargetIPAnnotation_IPCouldNotBeParsed(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -458,7 +459,7 @@ func TestTailnetTargetIPAnnotation_IPCouldNotBeParsed(t *testing.T) { Name: "test", Namespace: "default", - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ AnnotationTailnetTargetIP: tailnetTargetIP, }, @@ -466,7 +467,7 @@ func TestTailnetTargetIPAnnotation_IPCouldNotBeParsed(t *testing.T) { Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), }, }) @@ -486,7 +487,7 @@ func TestTailnetTargetIPAnnotation_IPCouldNotBeParsed(t *testing.T) { Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), }, Status: corev1.ServiceStatus{ Conditions: []metav1.Condition{{ @@ -511,7 +512,7 @@ func TestTailnetTargetIPAnnotation_InvalidIP(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -526,7 +527,7 @@ func TestTailnetTargetIPAnnotation_InvalidIP(t *testing.T) { Name: "test", Namespace: "default", - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ AnnotationTailnetTargetIP: tailnetTargetIP, }, @@ -534,7 +535,7 @@ func TestTailnetTargetIPAnnotation_InvalidIP(t *testing.T) { Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), }, }) @@ -554,7 +555,7 @@ func TestTailnetTargetIPAnnotation_InvalidIP(t *testing.T) { Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), }, Status: corev1.ServiceStatus{ Conditions: []metav1.Condition{{ @@ -579,7 +580,7 @@ func TestAnnotations(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -597,7 +598,7 @@ func TestAnnotations(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ "tailscale.com/expose": "true", }, @@ -612,7 +613,7 @@ func TestAnnotations(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -664,7 +665,7 @@ func TestAnnotations(t *testing.T) { ObjectMeta: metav1.ObjectMeta{ Name: "test", Namespace: "default", - UID: types.UID("1234-UID"), + UID: "1234-UID", }, Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", @@ -683,7 +684,7 @@ func TestAnnotationIntoLB(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -701,7 +702,7 @@ func TestAnnotationIntoLB(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ "tailscale.com/expose": "true", }, @@ -716,7 +717,7 @@ func TestAnnotationIntoLB(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -767,7 +768,7 @@ func TestAnnotationIntoLB(t *testing.T) { mustUpdate(t, fc, "default", "test", func(s *corev1.Service) { delete(s.ObjectMeta.Annotations, "tailscale.com/expose") s.Spec.Type = corev1.ServiceTypeLoadBalancer - s.Spec.LoadBalancerClass = ptr.To("tailscale") + s.Spec.LoadBalancerClass = new("tailscale") }) expectReconciled(t, sr, "default", "test") // None of the proxy machinery should have changed... @@ -780,12 +781,12 @@ func TestAnnotationIntoLB(t *testing.T) { Name: "test", Namespace: "default", Finalizers: []string{"tailscale.com/finalizer"}, - UID: types.UID("1234-UID"), + UID: "1234-UID", }, Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), }, Status: corev1.ServiceStatus{ LoadBalancer: corev1.LoadBalancerStatus{ @@ -813,7 +814,7 @@ func TestLBIntoAnnotation(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -831,12 +832,12 @@ func TestLBIntoAnnotation(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", }, Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), }, }) @@ -844,7 +845,7 @@ func TestLBIntoAnnotation(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -880,7 +881,7 @@ func TestLBIntoAnnotation(t *testing.T) { Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), }, Status: corev1.ServiceStatus{ LoadBalancer: corev1.LoadBalancerStatus{ @@ -926,7 +927,7 @@ func TestLBIntoAnnotation(t *testing.T) { Annotations: map[string]string{ "tailscale.com/expose": "true", }, - UID: types.UID("1234-UID"), + UID: "1234-UID", }, Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", @@ -948,7 +949,7 @@ func TestCustomHostname(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -966,7 +967,7 @@ func TestCustomHostname(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ "tailscale.com/expose": "true", "tailscale.com/hostname": "reindeer-flotilla", @@ -982,7 +983,7 @@ func TestCustomHostname(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -1035,7 +1036,7 @@ func TestCustomHostname(t *testing.T) { ObjectMeta: metav1.ObjectMeta{ Name: "test", Namespace: "default", - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ "tailscale.com/hostname": "reindeer-flotilla", }, @@ -1057,7 +1058,7 @@ func TestCustomPriorityClassName(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -1076,7 +1077,7 @@ func TestCustomPriorityClassName(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ "tailscale.com/expose": "true", "tailscale.com/hostname": "tailscale-critical", @@ -1092,7 +1093,7 @@ func TestCustomPriorityClassName(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -1213,7 +1214,7 @@ func TestServiceProxyClassAnnotation(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -1309,7 +1310,7 @@ func TestProxyClassForService(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -1327,18 +1328,18 @@ func TestProxyClassForService(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", }, Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), }, }) expectReconciled(t, sr, "default", "test") fullName, shortName := findGenName(t, fc, "default", "test", "svc") opts := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -1398,7 +1399,7 @@ func TestDefaultLoadBalancer(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -1417,7 +1418,7 @@ func TestDefaultLoadBalancer(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", }, Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", @@ -1431,7 +1432,7 @@ func TestDefaultLoadBalancer(t *testing.T) { expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) o := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -1452,7 +1453,7 @@ func TestProxyFirewallMode(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -1472,7 +1473,7 @@ func TestProxyFirewallMode(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", }, Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", @@ -1484,7 +1485,7 @@ func TestProxyFirewallMode(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -1499,24 +1500,28 @@ func TestProxyFirewallMode(t *testing.T) { func Test_isMagicDNSName(t *testing.T) { tests := []struct { + name string in string want bool }{ { + name: "foo-tail4567-ts-net", in: "foo.tail4567.ts.net", want: true, }, { + name: "foo-tail4567-ts-net-trailing-dot", in: "foo.tail4567.ts.net.", want: true, }, { + name: "foo-tail4567", in: "foo.tail4567", want: false, }, } for _, tt := range tests { - t.Run(tt.in, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { if got := isMagicDNSName(tt.in); got != tt.want { t.Errorf("isMagicDNSName(%q) = %v, want %v", tt.in, got, tt.want) } @@ -1542,7 +1547,7 @@ func Test_HeadlessService(t *testing.T) { Name: "test", Namespace: "default", - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ AnnotationExpose: "true", }, @@ -1596,7 +1601,7 @@ func Test_serviceHandlerForIngress(t *testing.T) { Name: "ing-1", Namespace: "ns-1", }, - Spec: networkingv1.IngressSpec{IngressClassName: ptr.To(tailscaleIngressClassName)}, + Spec: networkingv1.IngressSpec{IngressClassName: new(tailscaleIngressClassName)}, }) svc1 := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ @@ -1628,7 +1633,7 @@ func Test_serviceHandlerForIngress(t *testing.T) { DefaultBackend: &networkingv1.IngressBackend{ Service: &networkingv1.IngressServiceBackend{Name: "def-backend"}, }, - IngressClassName: ptr.To(tailscaleIngressClassName), + IngressClassName: new(tailscaleIngressClassName), }, }) backendSvc := &corev1.Service{ @@ -1652,7 +1657,7 @@ func Test_serviceHandlerForIngress(t *testing.T) { Namespace: "ns-3", }, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To(tailscaleIngressClassName), + IngressClassName: new(tailscaleIngressClassName), Rules: []networkingv1.IngressRule{{IngressRuleValue: networkingv1.IngressRuleValue{HTTP: &networkingv1.HTTPIngressRuleValue{ Paths: []networkingv1.HTTPIngressPath{ {Backend: networkingv1.IngressBackend{Service: &networkingv1.IngressServiceBackend{Name: "backend"}}}, @@ -1727,7 +1732,7 @@ func Test_serviceHandlerForIngress_multipleIngressClasses(t *testing.T) { mustCreate(t, fc, &networkingv1.Ingress{ ObjectMeta: metav1.ObjectMeta{Name: "nginx-ing", Namespace: "default"}, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("nginx"), + IngressClassName: new("nginx"), DefaultBackend: &networkingv1.IngressBackend{Service: &networkingv1.IngressServiceBackend{Name: "backend"}}, }, }) @@ -1735,7 +1740,7 @@ func Test_serviceHandlerForIngress_multipleIngressClasses(t *testing.T) { mustCreate(t, fc, &networkingv1.Ingress{ ObjectMeta: metav1.ObjectMeta{Name: "ts-ing", Namespace: "default"}, Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), + IngressClassName: new("tailscale"), DefaultBackend: &networkingv1.IngressBackend{Service: &networkingv1.IngressServiceBackend{Name: "backend"}}, }, }) @@ -1757,7 +1762,7 @@ func Test_clusterDomainFromResolverConf(t *testing.T) { want string }{ { - name: "success- custom domain", + name: "success-custom-domain", conf: &resolvconffile.Config{ SearchDomains: []dnsname.FQDN{toFQDN(t, "foo.svc.department.org.io"), toFQDN(t, "svc.department.org.io"), toFQDN(t, "department.org.io")}, }, @@ -1765,7 +1770,7 @@ func Test_clusterDomainFromResolverConf(t *testing.T) { want: "department.org.io", }, { - name: "success- default domain", + name: "success-default-domain", conf: &resolvconffile.Config{ SearchDomains: []dnsname.FQDN{toFQDN(t, "foo.svc.cluster.local."), toFQDN(t, "svc.cluster.local."), toFQDN(t, "cluster.local.")}, }, @@ -1773,7 +1778,7 @@ func Test_clusterDomainFromResolverConf(t *testing.T) { want: "cluster.local", }, { - name: "only two search domains found", + name: "only-two-search-domains", conf: &resolvconffile.Config{ SearchDomains: []dnsname.FQDN{toFQDN(t, "svc.department.org.io"), toFQDN(t, "department.org.io")}, }, @@ -1781,7 +1786,7 @@ func Test_clusterDomainFromResolverConf(t *testing.T) { want: "cluster.local", }, { - name: "first search domain does not match the expected structure", + name: "first-search-domain-mismatch", conf: &resolvconffile.Config{ SearchDomains: []dnsname.FQDN{toFQDN(t, "foo.bar.department.org.io"), toFQDN(t, "svc.department.org.io"), toFQDN(t, "some.other.fqdn")}, }, @@ -1789,7 +1794,7 @@ func Test_clusterDomainFromResolverConf(t *testing.T) { want: "cluster.local", }, { - name: "second search domain does not match the expected structure", + name: "second-search-domain-mismatch", conf: &resolvconffile.Config{ SearchDomains: []dnsname.FQDN{toFQDN(t, "foo.svc.department.org.io"), toFQDN(t, "foo.department.org.io"), toFQDN(t, "some.other.fqdn")}, }, @@ -1797,7 +1802,7 @@ func Test_clusterDomainFromResolverConf(t *testing.T) { want: "cluster.local", }, { - name: "third search domain does not match the expected structure", + name: "third-search-domain-mismatch", conf: &resolvconffile.Config{ SearchDomains: []dnsname.FQDN{toFQDN(t, "foo.svc.department.org.io"), toFQDN(t, "svc.department.org.io"), toFQDN(t, "some.other.fqdn")}, }, @@ -1826,7 +1831,7 @@ func Test_authKeyRemoval(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -1839,12 +1844,12 @@ func Test_authKeyRemoval(t *testing.T) { ObjectMeta: metav1.ObjectMeta{ Name: "test", Namespace: "default", - UID: types.UID("1234-UID"), + UID: "1234-UID", }, Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), }, }) @@ -1859,7 +1864,7 @@ func Test_authKeyRemoval(t *testing.T) { hostname: "default-test", clusterTargetIP: "10.20.30.40", app: kubetypes.AppIngressProxy, - replicas: ptr.To[int32](1), + replicas: new(int32(1)), } expectEqual(t, fc, expectedSecret(t, fc, opts)) @@ -1891,7 +1896,7 @@ func Test_externalNameService(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -1909,7 +1914,7 @@ func Test_externalNameService(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ AnnotationExpose: "true", }, @@ -1924,7 +1929,7 @@ func Test_externalNameService(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") opts := configOpts{ - replicas: ptr.To[int32](1), + replicas: new(int32(1)), stsName: shortName, secretName: fullName, namespace: "default", @@ -1969,7 +1974,7 @@ func Test_metricsResourceCreation(t *testing.T) { Spec: corev1.ServiceSpec{ ClusterIP: "10.20.30.40", Type: corev1.ServiceTypeLoadBalancer, - LoadBalancerClass: ptr.To("tailscale"), + LoadBalancerClass: new("tailscale"), }, } crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} @@ -1985,7 +1990,7 @@ func Test_metricsResourceCreation(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), operatorNamespace: "operator-ns", }, logger: zl.Sugar(), @@ -2056,7 +2061,7 @@ func TestIgnorePGService(t *testing.T) { Client: fc, ssr: &tailscaleSTSReconciler{ Client: fc, - tsClient: ft, + clients: tsclient.NewProvider(ft), defaultTags: []string{"tag:k8s"}, operatorNamespace: "operator-ns", proxyImage: "tailscale/tailscale", @@ -2074,7 +2079,7 @@ func TestIgnorePGService(t *testing.T) { // The apiserver is supposed to set the UID, but the fake client // doesn't. So, set it explicitly because other code later depends // on it being set. - UID: types.UID("1234-UID"), + UID: "1234-UID", Annotations: map[string]string{ "tailscale.com/proxygroup": "test-pg", }, diff --git a/cmd/k8s-operator/proxygroup_specs.go b/cmd/k8s-operator/proxygroup_specs.go index 05e0ed0b26013..60b4bddd5613c 100644 --- a/cmd/k8s-operator/proxygroup_specs.go +++ b/cmd/k8s-operator/proxygroup_specs.go @@ -7,6 +7,7 @@ package main import ( "fmt" + "maps" "slices" "strconv" "strings" @@ -22,7 +23,6 @@ import ( "tailscale.com/kube/egressservices" "tailscale.com/kube/ingressservices" "tailscale.com/kube/kubetypes" - "tailscale.com/types/ptr" ) const ( @@ -87,7 +87,7 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode string Labels: pgLabels(pg.Name, nil), OwnerReferences: pgOwnerReference(pg), } - ss.Spec.Replicas = ptr.To(pgReplicas(pg)) + ss.Spec.Replicas = new(pgReplicas(pg)) ss.Spec.Selector = &metav1.LabelSelector{ MatchLabels: pgLabels(pg.Name, nil), } @@ -98,7 +98,7 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode string Name: pg.Name, Namespace: namespace, Labels: pgLabels(pg.Name, nil), - DeletionGracePeriodSeconds: ptr.To[int64](10), + DeletionGracePeriodSeconds: new(int64(10)), } tmpl.Spec.ServiceAccountName = pg.Name tmpl.Spec.InitContainers[0].Image = image @@ -282,7 +282,7 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode string } // Set the deletion grace period to 6 minutes to ensure that the pre-stop hook has enough time to terminate // gracefully. - ss.Spec.Template.DeletionGracePeriodSeconds = ptr.To(deletionGracePeriodSeconds) + ss.Spec.Template.DeletionGracePeriodSeconds = new(deletionGracePeriodSeconds) } return ss, nil @@ -297,7 +297,7 @@ func kubeAPIServerStatefulSet(pg *tsapi.ProxyGroup, namespace, image string, por OwnerReferences: pgOwnerReference(pg), }, Spec: appsv1.StatefulSetSpec{ - Replicas: ptr.To(pgReplicas(pg)), + Replicas: new(pgReplicas(pg)), Selector: &metav1.LabelSelector{ MatchLabels: pgLabels(pg.Name, nil), }, @@ -306,7 +306,7 @@ func kubeAPIServerStatefulSet(pg *tsapi.ProxyGroup, namespace, image string, por Name: pg.Name, Namespace: namespace, Labels: pgLabels(pg.Name, nil), - DeletionGracePeriodSeconds: ptr.To[int64](10), + DeletionGracePeriodSeconds: new(int64(10)), }, Spec: corev1.PodSpec{ ServiceAccountName: pgServiceAccountName(pg), @@ -545,9 +545,7 @@ func pgSecretLabels(pgName, secretType string) map[string]string { func pgLabels(pgName string, customLabels map[string]string) map[string]string { labels := make(map[string]string, len(customLabels)+3) - for k, v := range customLabels { - labels[k] = v - } + maps.Copy(labels, customLabels) labels[kubetypes.LabelManaged] = "true" labels[LabelParentType] = "proxygroup" diff --git a/cmd/k8s-operator/sts_test.go b/cmd/k8s-operator/sts_test.go index 81c0d25ec0ba4..f55f582a6de42 100644 --- a/cmd/k8s-operator/sts_test.go +++ b/cmd/k8s-operator/sts_test.go @@ -8,6 +8,7 @@ package main import ( _ "embed" "fmt" + "maps" "reflect" "regexp" "strings" @@ -22,7 +23,6 @@ import ( "sigs.k8s.io/yaml" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" - "tailscale.com/types/ptr" ) // Test_statefulSetNameBase tests that parent name portion in a StatefulSet name @@ -69,7 +69,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { Labels: tsapi.Labels{"bar": "foo"}, Annotations: map[string]string{"bar.io/foo": "foo"}, SecurityContext: &corev1.PodSecurityContext{ - RunAsUser: ptr.To(int64(0)), + RunAsUser: new(int64(0)), }, ImagePullSecrets: []corev1.LocalObjectReference{{Name: "docker-creds"}}, NodeName: "some-node", @@ -87,18 +87,18 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { }, }, }, - DNSPolicy: ptr.To(corev1.DNSClusterFirstWithHostNet), + DNSPolicy: new(corev1.DNSClusterFirstWithHostNet), DNSConfig: &corev1.PodDNSConfig{ Nameservers: []string{"1.1.1.1", "8.8.8.8"}, Searches: []string{"example.com", "test.local"}, Options: []corev1.PodDNSConfigOption{ - {Name: "ndots", Value: ptr.To("2")}, + {Name: "ndots", Value: new("2")}, {Name: "edns0"}, }, }, TailscaleContainer: &tsapi.Container{ SecurityContext: &corev1.SecurityContext{ - Privileged: ptr.To(true), + Privileged: new(true), }, Resources: corev1.ResourceRequirements{ Limits: corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("1000m"), corev1.ResourceMemory: resource.MustParse("128Mi")}, @@ -110,8 +110,8 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { }, TailscaleInitContainer: &tsapi.Container{ SecurityContext: &corev1.SecurityContext{ - Privileged: ptr.To(true), - RunAsUser: ptr.To(int64(0)), + Privileged: new(true), + RunAsUser: new(int64(0)), }, Resources: corev1.ResourceRequirements{ Limits: corev1.ResourceList{corev1.ResourceCPU: resource.MustParse("1000m"), corev1.ResourceMemory: resource.MustParse("128Mi")}, @@ -293,7 +293,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { corev1.EnvVar{Name: "TS_ENABLE_METRICS", Value: "true"}, ) wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{{Name: "metrics", Protocol: "TCP", ContainerPort: 9002}} - gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(true, ptr.To(false)), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) + gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(true, new(false)), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { t.Errorf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) } @@ -305,7 +305,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(TS_DEBUG_ADDR_PORT)"}, ) wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{{Name: "debug", Protocol: "TCP", ContainerPort: 9001}} - gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(false, ptr.To(true)), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) + gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(false, new(true)), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { t.Errorf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) } @@ -324,76 +324,76 @@ func Test_mergeStatefulSetLabelsOrAnnots(t *testing.T) { want map[string]string }{ { - name: "no custom labels specified and none present in current labels, return current labels", + name: "no-custom-labels-none-present", current: map[string]string{kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, want: map[string]string{kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, managed: tailscaleManagedLabels, }, { - name: "no custom labels specified, but some present in current labels, return tailscale managed labels only from the current labels", + name: "no-custom-labels-some-present", current: map[string]string{"foo": "bar", "something.io/foo": "bar", kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, want: map[string]string{kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, managed: tailscaleManagedLabels, }, { - name: "custom labels specified, current labels only contain tailscale managed labels, return a union of both", + name: "custom-labels-with-managed-only", current: map[string]string{kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, custom: map[string]string{"foo": "bar", "something.io/foo": "bar"}, want: map[string]string{"foo": "bar", "something.io/foo": "bar", kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, managed: tailscaleManagedLabels, }, { - name: "custom labels specified, current labels contain tailscale managed labels and custom labels, some of which re not present in the new custom labels, return a union of managed labels and the desired custom labels", + name: "custom-labels-stale-removed", current: map[string]string{"foo": "bar", "bar": "baz", "app": "1234", kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, custom: map[string]string{"foo": "bar", "something.io/foo": "bar"}, want: map[string]string{"foo": "bar", "something.io/foo": "bar", "app": "1234", kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, managed: tailscaleManagedLabels, }, { - name: "no current labels present, return custom labels only", + name: "no-current-labels-return-custom", custom: map[string]string{"foo": "bar", "something.io/foo": "bar"}, want: map[string]string{"foo": "bar", "something.io/foo": "bar"}, managed: tailscaleManagedLabels, }, { - name: "no current labels present, no custom labels specified, return empty map", + name: "no-current-no-custom-return-empty", want: map[string]string{}, managed: tailscaleManagedLabels, }, { - name: "no custom annots specified and none present in current annots, return current annots", + name: "no-custom-annots-none-present", current: map[string]string{podAnnotationLastSetClusterIP: "1.2.3.4"}, want: map[string]string{podAnnotationLastSetClusterIP: "1.2.3.4"}, managed: tailscaleManagedAnnotations, }, { - name: "no custom annots specified, but some present in current annots, return tailscale managed annots only from the current annots", + name: "no-custom-annots-some-present", current: map[string]string{"foo": "bar", "something.io/foo": "bar", podAnnotationLastSetClusterIP: "1.2.3.4"}, want: map[string]string{podAnnotationLastSetClusterIP: "1.2.3.4"}, managed: tailscaleManagedAnnotations, }, { - name: "custom annots specified, current annots only contain tailscale managed annots, return a union of both", + name: "custom-annots-with-managed-only", current: map[string]string{podAnnotationLastSetClusterIP: "1.2.3.4"}, custom: map[string]string{"foo": "bar", "something.io/foo": "bar"}, want: map[string]string{"foo": "bar", "something.io/foo": "bar", podAnnotationLastSetClusterIP: "1.2.3.4"}, managed: tailscaleManagedAnnotations, }, { - name: "custom annots specified, current annots contain tailscale managed annots and custom annots, some of which are not present in the new custom annots, return a union of managed annots and the desired custom annots", + name: "custom-annots-stale-removed", current: map[string]string{"foo": "bar", "something.io/foo": "bar", podAnnotationLastSetClusterIP: "1.2.3.4"}, custom: map[string]string{"something.io/foo": "bar"}, want: map[string]string{"something.io/foo": "bar", podAnnotationLastSetClusterIP: "1.2.3.4"}, managed: tailscaleManagedAnnotations, }, { - name: "no current annots present, return custom annots only", + name: "no-current-annots-return-custom", custom: map[string]string{"foo": "bar", "something.io/foo": "bar"}, want: map[string]string{"foo": "bar", "something.io/foo": "bar"}, managed: tailscaleManagedAnnotations, }, { - name: "no current labels present, no custom labels specified, return empty map", + name: "no-current-annots-no-custom-return-empty", want: map[string]string{}, managed: tailscaleManagedAnnotations, }, @@ -409,7 +409,5 @@ func Test_mergeStatefulSetLabelsOrAnnots(t *testing.T) { // updateMap updates map a with the values from map b. func updateMap(a, b map[string]string) { - for key, val := range b { - a[key] = val - } + maps.Copy(a, b) } diff --git a/cmd/k8s-operator/svc.go b/cmd/k8s-operator/svc.go index 31be22aa12ca3..6f12148c85807 100644 --- a/cmd/k8s-operator/svc.go +++ b/cmd/k8s-operator/svc.go @@ -42,8 +42,6 @@ const ( reasonProxyInvalid = "ProxyInvalid" reasonProxyFailed = "ProxyFailed" reasonProxyPending = "ProxyPending" - - indexServiceProxyClass = ".metadata.annotations.service-proxy-class" ) type ServiceReconciler struct { @@ -97,7 +95,7 @@ func childResourceLabels(name, ns, typ string) map[string]string { func (a *ServiceReconciler) isTailscaleService(svc *corev1.Service) bool { targetIP := tailnetTargetAnnotation(svc) targetFQDN := svc.Annotations[AnnotationTailnetTargetFQDN] - return a.shouldExpose(svc) || targetIP != "" || targetFQDN != "" + return shouldExpose(svc, a.isDefaultLoadBalancer) || targetIP != "" || targetFQDN != "" } func (a *ServiceReconciler) Reconcile(ctx context.Context, req reconcile.Request) (_ reconcile.Result, err error) { @@ -164,7 +162,7 @@ func (a *ServiceReconciler) maybeCleanup(ctx context.Context, logger *zap.Sugare } proxyTyp := proxyTypeEgress - if a.shouldExpose(svc) { + if shouldExpose(svc, a.isDefaultLoadBalancer) { proxyTyp = proxyTypeIngressService } @@ -275,16 +273,16 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga LoginServer: a.ssr.loginServer, } sts.proxyType = proxyTypeEgress - if a.shouldExpose(svc) { + if shouldExpose(svc, a.isDefaultLoadBalancer) { sts.proxyType = proxyTypeIngressService } a.mu.Lock() - if a.shouldExposeClusterIP(svc) { + if shouldExposeClusterIP(svc, a.isDefaultLoadBalancer) { sts.ClusterTargetIP = svc.Spec.ClusterIP a.managedIngressProxies.Add(svc.UID) gaugeIngressProxies.Set(int64(a.managedIngressProxies.Len())) - } else if a.shouldExposeDNSName(svc) { + } else if shouldExposeDNSName(svc) { sts.ClusterTargetDNSName = svc.Spec.ExternalName a.managedIngressProxies.Add(svc.UID) gaugeIngressProxies.Set(int64(a.managedIngressProxies.Len())) @@ -410,19 +408,19 @@ func validateService(svc *corev1.Service) []string { return violations } -func (a *ServiceReconciler) shouldExpose(svc *corev1.Service) bool { - return a.shouldExposeClusterIP(svc) || a.shouldExposeDNSName(svc) +func shouldExpose(svc *corev1.Service, isDefaultLoadBalancer bool) bool { + return shouldExposeClusterIP(svc, isDefaultLoadBalancer) || shouldExposeDNSName(svc) } -func (a *ServiceReconciler) shouldExposeDNSName(svc *corev1.Service) bool { +func shouldExposeDNSName(svc *corev1.Service) bool { return hasExposeAnnotation(svc) && svc.Spec.Type == corev1.ServiceTypeExternalName && svc.Spec.ExternalName != "" } -func (a *ServiceReconciler) shouldExposeClusterIP(svc *corev1.Service) bool { +func shouldExposeClusterIP(svc *corev1.Service, isDefaultLoadBalancer bool) bool { if svc.Spec.ClusterIP == "" { return false } - return isTailscaleLoadBalancerService(svc, a.isDefaultLoadBalancer) || hasExposeAnnotation(svc) + return isTailscaleLoadBalancerService(svc, isDefaultLoadBalancer) || hasExposeAnnotation(svc) } func isTailscaleLoadBalancerService(svc *corev1.Service, isDefaultLoadBalancer bool) bool { diff --git a/cmd/k8s-operator/svc_test.go b/cmd/k8s-operator/svc_test.go new file mode 100644 index 0000000000000..677e9db10d40d --- /dev/null +++ b/cmd/k8s-operator/svc_test.go @@ -0,0 +1,223 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "slices" + "testing" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/k8s-operator/tsclient" + "tailscale.com/kube/kubetypes" + "tailscale.com/tstest" +) + +func TestService_DefaultProxyClassInitiallyNotReady(t *testing.T) { + pc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{Name: "custom-metadata"}, + Spec: tsapi.ProxyClassSpec{ + TailscaleConfig: &tsapi.TailscaleConfig{ + AcceptRoutes: true, + }, + StatefulSet: &tsapi.StatefulSet{ + Labels: tsapi.Labels{"foo": "bar"}, + Annotations: map[string]string{"bar.io/foo": "some-val"}, + Pod: &tsapi.Pod{Annotations: map[string]string{"foo.io/bar": "some-val"}}, + }, + }, + } + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pc). + WithStatusSubresource(pc). + Build() + ft := &fakeTSClient{} + zl := zap.Must(zap.NewDevelopment()) + clock := tstest.NewClock(tstest.ClockOpts{}) + sr := &ServiceReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + clients: tsclient.NewProvider(ft), + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + defaultProxyClass: "custom-metadata", + logger: zl.Sugar(), + clock: clock, + } + + // 1. A new tailscale LoadBalancer Service is created but the default + // ProxyClass is not ready yet. + mustCreate(t, fc, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + // The apiserver is supposed to set the UID, but the fake client + // doesn't. So, set it explicitly because other code later depends + // on it being set. + UID: types.UID("1234-UID"), + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: new("tailscale"), + }, + }) + expectReconciled(t, sr, "default", "test") + labels := map[string]string{ + kubetypes.LabelManaged: "true", + LabelParentName: "test", + LabelParentNamespace: "operator-ns", + LabelParentType: "svc", + } + s, err := getSingleObject[corev1.Secret](context.Background(), fc, "operator-ns", labels) + if err != nil { + t.Fatalf("finding Secret for %q: %v", "test", err) + } + if s != nil { + t.Fatalf("expected no Secret to be created when default ProxyClass is not ready, but found one: %v", s) + } + + // 2. ProxyClass is set to Ready, the Service can become ready now. + mustUpdateStatus(t, fc, "", "custom-metadata", func(pc *tsapi.ProxyClass) { + pc.Status = tsapi.ProxyClassStatus{ + Conditions: []metav1.Condition{{ + Status: metav1.ConditionTrue, + Type: string(tsapi.ProxyClassReady), + ObservedGeneration: pc.Generation, + }}, + } + }) + expectReconciled(t, sr, "default", "test") + fullName, shortName := findGenName(t, fc, "default", "test", "svc") + opts := configOpts{ + replicas: new(int32(1)), + stsName: shortName, + secretName: fullName, + namespace: "default", + parentType: "svc", + hostname: "default-test", + clusterTargetIP: "10.20.30.40", + app: kubetypes.AppIngressProxy, + proxyClass: pc.Name, + } + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) +} + +func TestProxyClassHandlerForSvc(t *testing.T) { + svc := func(name string, annotations, labels map[string]string) *corev1.Service { + return &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: "default", + Annotations: annotations, + Labels: labels, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "1.2.3.4", + }, + } + } + lbSvc := func(name string, annotations map[string]string, class *string) *corev1.Service { + return &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: "foo", + Annotations: annotations, + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: class, + ClusterIP: "1.2.3.4", + }, + } + } + + const ( + defaultPCName = "default-proxyclass" + otherPCName = "other-proxyclass" + unreferencedPCName = "unreferenced-proxyclass" + ) + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithIndex(&corev1.Service{}, indexServiceProxyClass, indexProxyClass). + WithIndex(&corev1.Service{}, indexServiceExposed, indexExposed). + WithIndex(&corev1.Service{}, indexServiceType, indexType). + WithObjects( + svc("not-exposed", nil, nil), + svc("exposed-default", map[string]string{AnnotationExpose: "true"}, nil), + svc("exposed-other", map[string]string{AnnotationExpose: "true", LabelAnnotationProxyClass: otherPCName}, nil), + svc("annotated", map[string]string{LabelAnnotationProxyClass: defaultPCName}, nil), + svc("labelled", nil, map[string]string{LabelAnnotationProxyClass: defaultPCName}), + lbSvc("lb-svc", nil, new("tailscale")), + lbSvc("lb-svc-no-class", nil, nil), + lbSvc("lb-svc-other-class", nil, new("other")), + lbSvc("lb-svc-other-pc", map[string]string{LabelAnnotationProxyClass: otherPCName}, nil), + ). + Build() + + zl := zap.Must(zap.NewDevelopment()) + mapFunc := proxyClassHandlerForSvc(fc, zl.Sugar(), defaultPCName, true) + + for _, tc := range []struct { + name string + proxyClassName string + expected []reconcile.Request + }{ + { + name: "default_ProxyClass", + proxyClassName: defaultPCName, + expected: []reconcile.Request{ + {NamespacedName: types.NamespacedName{Namespace: "default", Name: "exposed-default"}}, + {NamespacedName: types.NamespacedName{Namespace: "default", Name: "annotated"}}, + {NamespacedName: types.NamespacedName{Namespace: "default", Name: "labelled"}}, + {NamespacedName: types.NamespacedName{Namespace: "foo", Name: "lb-svc"}}, + {NamespacedName: types.NamespacedName{Namespace: "foo", Name: "lb-svc-no-class"}}, + }, + }, + { + name: "other_ProxyClass", + proxyClassName: otherPCName, + expected: []reconcile.Request{ + {NamespacedName: types.NamespacedName{Namespace: "default", Name: "exposed-other"}}, + {NamespacedName: types.NamespacedName{Namespace: "foo", Name: "lb-svc-other-pc"}}, + }, + }, + { + name: "unreferenced_ProxyClass", + proxyClassName: unreferencedPCName, + expected: nil, + }, + } { + t.Run(tc.name, func(t *testing.T) { + reqs := mapFunc(t.Context(), &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: tc.proxyClassName, + }, + }) + if len(reqs) != len(tc.expected) { + t.Fatalf("expected %d requests, got %d: %v", len(tc.expected), len(reqs), reqs) + } + for _, expected := range tc.expected { + if !slices.Contains(reqs, expected) { + t.Errorf("expected request for Service %q not found in results: %v", expected.Name, reqs) + } + } + }) + } +} diff --git a/cmd/k8s-operator/tsclient.go b/cmd/k8s-operator/tsclient.go index 063c2f768c6c6..0670d2bcf94a0 100644 --- a/cmd/k8s-operator/tsclient.go +++ b/cmd/k8s-operator/tsclient.go @@ -6,26 +6,18 @@ package main import ( - "context" "fmt" - "net/http" + "net/url" "os" - "sync" - "time" "go.uber.org/zap" - "golang.org/x/oauth2" - "golang.org/x/oauth2/clientcredentials" - "tailscale.com/internal/client/tailscale" + "tailscale.com/client/tailscale/v2" + "tailscale.com/ipn" - "tailscale.com/tailcfg" ) -// defaultTailnet is a value that can be used in Tailscale API calls instead of tailnet name to indicate that the API -// call should be performed on the default tailnet for the provided credentials. const ( - defaultTailnet = "-" - oidcJWTPath = "/var/run/secrets/tailscale/serviceaccount/token" + oidcJWTPath = "/var/run/secrets/tailscale/serviceaccount/token" ) func newTSClient(logger *zap.SugaredLogger, clientID, clientIDPath, clientSecretPath, loginServer string) (*tailscale.Client, error) { @@ -34,100 +26,45 @@ func newTSClient(logger *zap.SugaredLogger, clientID, clientIDPath, clientSecret baseURL = loginServer } - var httpClient *http.Client + base, err := url.Parse(baseURL) + if err != nil { + return nil, err + } + + client := &tailscale.Client{ + UserAgent: "tailscale-k8s-operator", + BaseURL: base, + } + if clientID == "" { // Use static client credentials mounted to disk. - id, err := os.ReadFile(clientIDPath) + clientIDBytes, err := os.ReadFile(clientIDPath) if err != nil { return nil, fmt.Errorf("error reading client ID %q: %w", clientIDPath, err) } - secret, err := os.ReadFile(clientSecretPath) + clientSecretBytes, err := os.ReadFile(clientSecretPath) if err != nil { return nil, fmt.Errorf("reading client secret %q: %w", clientSecretPath, err) } - credentials := clientcredentials.Config{ - ClientID: string(id), - ClientSecret: string(secret), - TokenURL: fmt.Sprintf("%s%s", baseURL, "/api/v2/oauth/token"), + + client.Auth = &tailscale.OAuth{ + ClientID: string(clientIDBytes), + ClientSecret: string(clientSecretBytes), } - tokenSrc := credentials.TokenSource(context.Background()) - httpClient = oauth2.NewClient(context.Background(), tokenSrc) } else { // Use workload identity federation. - tokenSrc := &jwtTokenSource{ - logger: logger, - jwtPath: oidcJWTPath, - baseCfg: clientcredentials.Config{ - ClientID: clientID, - TokenURL: fmt.Sprintf("%s%s", baseURL, "/api/v2/oauth/token-exchange"), + client.Auth = &tailscale.IdentityFederation{ + ClientID: clientID, + IDTokenFunc: func() (string, error) { + token, err := os.ReadFile(oidcJWTPath) + if err != nil { + return "", err + } + + return string(token), nil }, } - httpClient = &http.Client{ - Transport: &oauth2.Transport{ - Source: tokenSrc, - }, - } - } - - c := tailscale.NewClient(defaultTailnet, nil) - c.UserAgent = "tailscale-k8s-operator" - c.HTTPClient = httpClient - if loginServer != "" { - c.BaseURL = loginServer - } - return c, nil -} - -type tsClient interface { - CreateKey(ctx context.Context, caps tailscale.KeyCapabilities) (string, *tailscale.Key, error) - Device(ctx context.Context, deviceID string, fields *tailscale.DeviceFieldsOpts) (*tailscale.Device, error) - DeleteDevice(ctx context.Context, nodeStableID string) error - // GetVIPService is a method for getting a Tailscale Service. VIPService is the original name for Tailscale Service. - GetVIPService(ctx context.Context, name tailcfg.ServiceName) (*tailscale.VIPService, error) - // ListVIPServices is a method for listing all Tailscale Services. VIPService is the original name for Tailscale Service. - ListVIPServices(ctx context.Context) (*tailscale.VIPServiceList, error) - // CreateOrUpdateVIPService is a method for creating or updating a Tailscale Service. - CreateOrUpdateVIPService(ctx context.Context, svc *tailscale.VIPService) error - // DeleteVIPService is a method for deleting a Tailscale Service. - DeleteVIPService(ctx context.Context, name tailcfg.ServiceName) error -} - -// jwtTokenSource implements the [oauth2.TokenSource] interface, but with the -// ability to regenerate a fresh underlying token source each time a new value -// of the JWT parameter is needed due to expiration. -type jwtTokenSource struct { - logger *zap.SugaredLogger - jwtPath string // Path to the file containing an automatically refreshed JWT. - baseCfg clientcredentials.Config // Holds config that doesn't change for the lifetime of the process. - - mu sync.Mutex // Guards underlying. - underlying oauth2.TokenSource // The oauth2 client implementation. Does its own separate caching of the access token. -} - -func (s *jwtTokenSource) Token() (*oauth2.Token, error) { - s.mu.Lock() - defer s.mu.Unlock() - - if s.underlying != nil { - t, err := s.underlying.Token() - if err == nil && t != nil && t.Valid() { - return t, nil - } - } - - s.logger.Debugf("Refreshing JWT from %s", s.jwtPath) - tk, err := os.ReadFile(s.jwtPath) - if err != nil { - return nil, fmt.Errorf("error reading JWT from %q: %w", s.jwtPath, err) - } - - // Shallow copy of the base config. - credentials := s.baseCfg - credentials.EndpointParams = map[string][]string{ - "jwt": {string(tk)}, } - src := credentials.TokenSource(context.Background()) - s.underlying = oauth2.ReuseTokenSourceWithExpiry(nil, src, time.Minute) - return s.underlying.Token() + return client, nil } diff --git a/cmd/k8s-operator/tsclient_test.go b/cmd/k8s-operator/tsclient_test.go deleted file mode 100644 index c08705c78ed8b..0000000000000 --- a/cmd/k8s-operator/tsclient_test.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "testing" - - "go.uber.org/zap" - "golang.org/x/oauth2" -) - -func TestNewStaticClient(t *testing.T) { - const ( - clientIDFile = "client-id" - clientSecretFile = "client-secret" - ) - - tmp := t.TempDir() - clientIDPath := filepath.Join(tmp, clientIDFile) - if err := os.WriteFile(clientIDPath, []byte("test-client-id"), 0600); err != nil { - t.Fatalf("error writing test file %q: %v", clientIDPath, err) - } - clientSecretPath := filepath.Join(tmp, clientSecretFile) - if err := os.WriteFile(clientSecretPath, []byte("test-client-secret"), 0600); err != nil { - t.Fatalf("error writing test file %q: %v", clientSecretPath, err) - } - - srv := testAPI(t, 3600) - cl, err := newTSClient(zap.NewNop().Sugar(), "", clientIDPath, clientSecretPath, srv.URL) - if err != nil { - t.Fatalf("error creating Tailscale client: %v", err) - } - - resp, err := cl.HTTPClient.Get(srv.URL) - if err != nil { - t.Fatalf("error making test API call: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("error reading response body: %v", err) - } - want := "Bearer " + testToken("/api/v2/oauth/token", "test-client-id", "test-client-secret", "") - if string(got) != want { - t.Errorf("got %q; want %q", got, want) - } -} - -func TestNewWorkloadIdentityClient(t *testing.T) { - // 5 seconds is within expiryDelta leeway, so the access token will - // immediately be considered expired and get refreshed on each access. - srv := testAPI(t, 5) - cl, err := newTSClient(zap.NewNop().Sugar(), "test-client-id", "", "", srv.URL) - if err != nil { - t.Fatalf("error creating Tailscale client: %v", err) - } - - // Modify the path where the JWT will be read from. - oauth2Transport, ok := cl.HTTPClient.Transport.(*oauth2.Transport) - if !ok { - t.Fatalf("expected oauth2.Transport, got %T", cl.HTTPClient.Transport) - } - jwtTokenSource, ok := oauth2Transport.Source.(*jwtTokenSource) - if !ok { - t.Fatalf("expected jwtTokenSource, got %T", oauth2Transport.Source) - } - tmp := t.TempDir() - jwtPath := filepath.Join(tmp, "token") - jwtTokenSource.jwtPath = jwtPath - - for _, jwt := range []string{"test-jwt", "updated-test-jwt"} { - if err := os.WriteFile(jwtPath, []byte(jwt), 0600); err != nil { - t.Fatalf("error writing test file %q: %v", jwtPath, err) - } - resp, err := cl.HTTPClient.Get(srv.URL) - if err != nil { - t.Fatalf("error making test API call: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("error reading response body: %v", err) - } - if want := "Bearer " + testToken("/api/v2/oauth/token-exchange", "test-client-id", "", jwt); string(got) != want { - t.Errorf("got %q; want %q", got, want) - } - } -} - -func testAPI(t *testing.T, expirationSeconds int) *httptest.Server { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - t.Logf("test server got request: %s %s", r.Method, r.URL.Path) - switch r.URL.Path { - case "/api/v2/oauth/token", "/api/v2/oauth/token-exchange": - id, secret, ok := r.BasicAuth() - if !ok { - t.Fatal("missing or invalid basic auth") - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": testToken(r.URL.Path, id, secret, r.FormValue("jwt")), - "token_type": "Bearer", - "expires_in": expirationSeconds, - }); err != nil { - t.Fatalf("error writing response: %v", err) - } - case "/": - // Echo back the authz header for test assertions. - _, err := w.Write([]byte(r.Header.Get("Authorization"))) - if err != nil { - t.Fatalf("error writing response: %v", err) - } - default: - w.WriteHeader(http.StatusNotFound) - } - })) - t.Cleanup(srv.Close) - return srv -} - -func testToken(path, id, secret, jwt string) string { - return fmt.Sprintf("%s|%s|%s|%s", path, id, secret, jwt) -} diff --git a/cmd/k8s-operator/tsrecorder_specs.go b/cmd/k8s-operator/tsrecorder_specs.go index ab06c01f81b7d..5a93bc22b546c 100644 --- a/cmd/k8s-operator/tsrecorder_specs.go +++ b/cmd/k8s-operator/tsrecorder_specs.go @@ -7,6 +7,7 @@ package main import ( "fmt" + "maps" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" @@ -14,7 +15,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" - "tailscale.com/types/ptr" "tailscale.com/version" ) @@ -33,7 +33,7 @@ func tsrStatefulSet(tsr *tsapi.Recorder, namespace string, loginServer string) * Annotations: tsr.Spec.StatefulSet.Annotations, }, Spec: appsv1.StatefulSetSpec{ - Replicas: ptr.To(replicas), + Replicas: new(replicas), Selector: &metav1.LabelSelector{ MatchLabels: tsrLabels("recorder", tsr.Name, tsr.Spec.StatefulSet.Pod.Labels), }, @@ -313,9 +313,7 @@ func tsrEnv(tsr *tsapi.Recorder, loginServer string) []corev1.EnvVar { func tsrLabels(app, instance string, customLabels map[string]string) map[string]string { labels := make(map[string]string, len(customLabels)+3) - for k, v := range customLabels { - labels[k] = v - } + maps.Copy(labels, customLabels) // ref: https://kubernetes.io/docs/concepts/overview/working-with-objects/common-labels/ labels["app.kubernetes.io/name"] = app diff --git a/cmd/k8s-operator/tsrecorder_specs_test.go b/cmd/k8s-operator/tsrecorder_specs_test.go index 47997d1d31b0f..13da8a3c8781f 100644 --- a/cmd/k8s-operator/tsrecorder_specs_test.go +++ b/cmd/k8s-operator/tsrecorder_specs_test.go @@ -14,17 +14,16 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" - "tailscale.com/types/ptr" ) func TestRecorderSpecs(t *testing.T) { - t.Run("ensure spec fields are passed through correctly", func(t *testing.T) { + t.Run("spec-fields-passthrough", func(t *testing.T) { tsr := &tsapi.Recorder{ ObjectMeta: metav1.ObjectMeta{ Name: "test", }, Spec: tsapi.RecorderSpec{ - Replicas: ptr.To[int32](3), + Replicas: new(int32(3)), StatefulSet: tsapi.RecorderStatefulSet{ Labels: map[string]string{ "ss-label-key": "ss-label-value", @@ -51,7 +50,7 @@ func TestRecorderSpecs(t *testing.T) { }, }, SecurityContext: &corev1.PodSecurityContext{ - RunAsUser: ptr.To[int64](1000), + RunAsUser: new(int64(1000)), }, ImagePullSecrets: []corev1.LocalObjectReference{{ Name: "img-pull", @@ -62,7 +61,7 @@ func TestRecorderSpecs(t *testing.T) { Tolerations: []corev1.Toleration{{ Key: "key", Value: "value", - TolerationSeconds: ptr.To[int64](60), + TolerationSeconds: new(int64(60)), }}, Container: tsapi.RecorderContainer{ Env: []tsapi.Env{{ diff --git a/cmd/k8s-operator/tsrecorder_test.go b/cmd/k8s-operator/tsrecorder_test.go index 5d315f8c52e93..8f189728c0207 100644 --- a/cmd/k8s-operator/tsrecorder_test.go +++ b/cmd/k8s-operator/tsrecorder_test.go @@ -14,6 +14,7 @@ import ( "github.com/google/go-cmp/cmp" "go.uber.org/zap" + "golang.org/x/time/rate" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" @@ -21,11 +22,12 @@ import ( "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + "tailscale.com/client/tailscale/v2" tsoperator "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/k8s-operator/tsclient" "tailscale.com/tstest" - "tailscale.com/types/ptr" ) const ( @@ -40,7 +42,7 @@ func TestRecorder(t *testing.T) { Finalizers: []string{"tailscale.com/finalizer"}, }, Spec: tsapi.RecorderSpec{ - Replicas: ptr.To[int32](3), + Replicas: new(int32(3)), }, } @@ -49,18 +51,19 @@ func TestRecorder(t *testing.T) { WithObjects(tsr). WithStatusSubresource(tsr). Build() - tsClient := &fakeTSClient{} + tsClient := &fakeTSClient{loginURL: tsLoginServer} zl, _ := zap.NewDevelopment() fr := record.NewFakeRecorder(2) cl := tstest.NewClock(tstest.ClockOpts{}) reconciler := &RecorderReconciler{ - tsNamespace: tsNamespace, - Client: fc, - tsClient: tsClient, - recorder: fr, - log: zl.Sugar(), - clock: cl, - loginServer: tsLoginServer, + tsNamespace: tsNamespace, + Client: fc, + clients: tsclient.NewProvider(tsClient), + recorder: fr, + log: zl.Sugar(), + clock: cl, + authKeyRateLimits: make(map[string]*rate.Limiter), + authKeyReissuing: make(map[string]bool), } t.Run("invalid_spec_gives_an_error_condition", func(t *testing.T) { @@ -195,8 +198,8 @@ func TestRecorder(t *testing.T) { }) t.Run("populate_node_info_in_state_secret_and_see_it_appear_in_status", func(t *testing.T) { - const key = "profile-abc" + for replica := range *tsr.Spec.Replicas { bytes, err := json.Marshal(map[string]any{ "Config": map[string]any{ @@ -219,6 +222,24 @@ func TestRecorder(t *testing.T) { }) } + tsClient.devices = []tailscale.Device{ + { + ID: "node-0", + Hostname: "hostname-node-0", + Addresses: []string{"1.2.3.4", "::1"}, + }, + { + ID: "node-1", + Hostname: "hostname-node-1", + Addresses: []string{"1.2.3.4", "::1"}, + }, + { + ID: "node-2", + Hostname: "hostname-node-2", + Addresses: []string{"1.2.3.4", "::1"}, + }, + } + expectReconciled(t, reconciler, "", tsr.Name) tsr.Status.Devices = []tsapi.RecorderTailnetDevice{ { diff --git a/cmd/k8s-proxy/internal/config/config.go b/cmd/k8s-proxy/internal/config/config.go index 91b4c54a5c32d..c12383d45c470 100644 --- a/cmd/k8s-proxy/internal/config/config.go +++ b/cmd/k8s-proxy/internal/config/config.go @@ -27,7 +27,6 @@ import ( clientcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" "tailscale.com/kube/k8s-proxy/conf" "tailscale.com/kube/kubetypes" - "tailscale.com/types/ptr" "tailscale.com/util/testenv" ) @@ -178,7 +177,7 @@ func (ld *configLoader) watchConfigSecretChanges(ctx context.Context, secretName }, // Re-watch regularly to avoid relying on long-lived connections. // See https://github.com/kubernetes-client/javascript/issues/596#issuecomment-786419380 - TimeoutSeconds: ptr.To(int64(600)), + TimeoutSeconds: new(int64(600)), FieldSelector: fmt.Sprintf("metadata.name=%s", secretName), Watch: true, }) @@ -216,7 +215,7 @@ func (ld *configLoader) watchConfigSecretChanges(ctx context.Context, secretName Kind: "Secret", APIVersion: "v1", }, - TimeoutSeconds: ptr.To(int64(600)), + TimeoutSeconds: new(int64(600)), FieldSelector: fmt.Sprintf("metadata.name=%s", secretName), Watch: true, }) diff --git a/cmd/k8s-proxy/internal/config/config_test.go b/cmd/k8s-proxy/internal/config/config_test.go index ac6c6cf93f623..aedd29d4e1877 100644 --- a/cmd/k8s-proxy/internal/config/config_test.go +++ b/cmd/k8s-proxy/internal/config/config_test.go @@ -20,7 +20,6 @@ import ( ktesting "k8s.io/client-go/testing" "tailscale.com/kube/k8s-proxy/conf" "tailscale.com/kube/kubetypes" - "tailscale.com/types/ptr" ) func TestWatchConfig(t *testing.T) { @@ -52,7 +51,7 @@ func TestWatchConfig(t *testing.T) { initialConfig: `{"version": "v1alpha1", "authKey": "abc123"}`, phases: []phase{{ expectedConf: &conf.ConfigV1Alpha1{ - AuthKey: ptr.To("abc123"), + AuthKey: new("abc123"), }, }}, }, @@ -62,7 +61,7 @@ func TestWatchConfig(t *testing.T) { phases: []phase{ { expectedConf: &conf.ConfigV1Alpha1{ - AuthKey: ptr.To("abc123"), + AuthKey: new("abc123"), }, }, { @@ -76,13 +75,13 @@ func TestWatchConfig(t *testing.T) { phases: []phase{ { expectedConf: &conf.ConfigV1Alpha1{ - AuthKey: ptr.To("abc123"), + AuthKey: new("abc123"), }, }, { config: `{"version": "v1alpha1", "authKey": "def456"}`, expectedConf: &conf.ConfigV1Alpha1{ - AuthKey: ptr.To("def456"), + AuthKey: new("def456"), }, }, }, @@ -93,7 +92,7 @@ func TestWatchConfig(t *testing.T) { phases: []phase{ { expectedConf: &conf.ConfigV1Alpha1{ - AuthKey: ptr.To("abc123"), + AuthKey: new("abc123"), }, }, { diff --git a/cmd/k8s-proxy/k8s-proxy.go b/cmd/k8s-proxy/k8s-proxy.go index 38a86a5e0ade5..673493f58cecd 100644 --- a/cmd/k8s-proxy/k8s-proxy.go +++ b/cmd/k8s-proxy/k8s-proxy.go @@ -31,6 +31,7 @@ import ( "k8s.io/utils/strings/slices" "tailscale.com/client/local" "tailscale.com/cmd/k8s-proxy/internal/config" + "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/store" @@ -41,6 +42,7 @@ import ( "tailscale.com/kube/certs" healthz "tailscale.com/kube/health" "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubeclient" "tailscale.com/kube/kubetypes" klc "tailscale.com/kube/localclient" "tailscale.com/kube/metrics" @@ -171,10 +173,31 @@ func run(logger *zap.SugaredLogger) error { // If Pod UID unset, assume we're running outside of a cluster/not managed // by the operator, so no need to set additional state keys. + var kc kubeclient.Client + var stateSecretName string if podUID != "" { if err := state.SetInitialKeys(st, podUID); err != nil { return fmt.Errorf("error setting initial state: %w", err) } + + if cfg.Parsed.State != nil { + if name, ok := strings.CutPrefix(*cfg.Parsed.State, "kube:"); ok { + stateSecretName = name + + kc, err = kubeclient.New(k8sProxyFieldManager) + if err != nil { + return err + } + + var configAuthKey string + if cfg.Parsed.AuthKey != nil { + configAuthKey = *cfg.Parsed.AuthKey + } + if err := resetState(ctx, kc, stateSecretName, podUID, configAuthKey); err != nil { + return fmt.Errorf("error resetting state: %w", err) + } + } + } } var authKey string @@ -197,23 +220,69 @@ func run(logger *zap.SugaredLogger) error { ts.Hostname = *cfg.Parsed.Hostname } + lc, err := ts.LocalClient() + if err != nil { + return fmt.Errorf("error getting local client: %w", err) + } + // Make sure we crash loop if Up doesn't complete in reasonable time. - upCtx, upCancel := context.WithTimeout(ctx, time.Minute) + upCtx, upCancel := context.WithTimeout(ctx, 30*time.Second) defer upCancel() + + // ts.Up() deliberately ignores NeedsLogin because it fires transiently + // during normal auth-key login. We can watch for the login-state health + // warning here though, which only fires on terminal auth failure, and + // cancel early. + go func() { + w, err := lc.WatchIPNBus(upCtx, ipn.NotifyInitialHealthState) + if err != nil { + return + } + defer w.Close() + for { + n, err := w.Next() + if err != nil { + logger.Debugf("failed to process message from ipn bus: %s", err.Error()) + return + } + if n.Health != nil { + if _, ok := n.Health.Warnings[health.LoginStateWarnable.Code]; ok { + upCancel() + return + } + } + } + }() + if _, err := ts.Up(upCtx); err != nil { - return fmt.Errorf("error starting tailscale server: %w", err) + if kc != nil && stateSecretName != "" { + return handleAuthKeyReissue(ctx, lc, kc, stateSecretName, authKey, cfgChan, logger) + } + return err } + defer ts.Close() - lc, err := ts.LocalClient() - if err != nil { - return fmt.Errorf("error getting local client: %w", err) - } - // Setup for updating state keys. + reissueCh := make(chan struct{}, 1) if podUID != "" { group.Go(func() error { return state.KeepKeysUpdated(ctx, st, klc.New(lc)) }) + + if kc != nil && stateSecretName != "" { + needsReissue, err := checkInitialAuthState(ctx, lc) + if err != nil { + return fmt.Errorf("error checking initial auth state: %w", err) + } + if needsReissue { + logger.Info("Auth key missing or invalid after startup, requesting new key from operator") + return handleAuthKeyReissue(ctx, lc, kc, stateSecretName, authKey, cfgChan, logger) + } + + group.Go(func() error { + return monitorAuthHealth(ctx, lc, reissueCh, logger) + }) + } } if cfg.Parsed.HealthCheckEnabled.EqualBool(true) || cfg.Parsed.MetricsEnabled.EqualBool(true) { @@ -362,6 +431,8 @@ func run(logger *zap.SugaredLogger) error { } cfgLogger.Infof("Config reloaded") + case <-reissueCh: + return handleAuthKeyReissue(ctx, lc, kc, stateSecretName, authKey, cfgChan, logger) } } } diff --git a/cmd/k8s-proxy/kube.go b/cmd/k8s-proxy/kube.go new file mode 100644 index 0000000000000..1d9348f1a3bea --- /dev/null +++ b/cmd/k8s-proxy/kube.go @@ -0,0 +1,161 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "go.uber.org/zap" + "tailscale.com/client/local" + "tailscale.com/health" + "tailscale.com/ipn" + "tailscale.com/kube/authkey" + "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubeapi" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" +) + +const k8sProxyFieldManager = "tailscale-k8s-proxy" + +// resetState clears k8s-proxy state from previous runs and sets +// initial values. This ensures the operator doesn't use stale state when a Pod +// is first recreated. +// +// It also clears the reissue_authkey marker if the operator has actioned it +// (i.e., the config now has a different auth key than what was marked for +// reissue). +func resetState(ctx context.Context, kc kubeclient.Client, stateSecretName string, podUID string, configAuthKey string) error { + existingSecret, err := kc.GetSecret(ctx, stateSecretName) + switch { + case kubeclient.IsNotFoundErr(err): + return nil + case err != nil: + return fmt.Errorf("failed to read state Secret %q to reset state: %w", stateSecretName, err) + } + + s := &kubeapi.Secret{ + Data: map[string][]byte{ + kubetypes.KeyCapVer: fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion), + }, + } + if podUID != "" { + s.Data[kubetypes.KeyPodUID] = []byte(podUID) + } + + // Only clear reissue_authkey if the operator has actioned it. + brokenAuthkey, ok := existingSecret.Data[kubetypes.KeyReissueAuthkey] + if ok && configAuthKey != "" && string(brokenAuthkey) != configAuthKey { + s.Data[kubetypes.KeyReissueAuthkey] = nil + } + + return kc.StrategicMergePatchSecret(ctx, stateSecretName, s, k8sProxyFieldManager) +} + +// needsAuthKeyReissue reports whether the given backend state and health +// warnings indicate a terminal auth failure requiring a new key from the +// operator. +func needsAuthKeyReissue(backendState string, healthWarnings []string) bool { + if backendState == ipn.NeedsLogin.String() { + return true + } + loginWarnableCode := string(health.LoginStateWarnable.Code) + for _, h := range healthWarnings { + if strings.Contains(h, loginWarnableCode) { + return true + } + } + return false +} + +// checkInitialAuthState checks if the tsnet server is in an auth failure state +// immediately after coming up. Returns true if auth key reissue is needed. +func checkInitialAuthState(ctx context.Context, lc *local.Client) (bool, error) { + status, err := lc.Status(ctx) + if err != nil { + return false, fmt.Errorf("error getting status: %w", err) + } + return needsAuthKeyReissue(status.BackendState, status.Health), nil +} + +// monitorAuthHealth watches the IPN bus for auth failures and triggers reissue +// when needed. Runs until context is cancelled or auth failure is detected. +func monitorAuthHealth(ctx context.Context, lc *local.Client, reissueCh chan<- struct{}, logger *zap.SugaredLogger) error { + w, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialHealthState) + if err != nil { + return fmt.Errorf("failed to watch IPN bus for auth health: %w", err) + } + defer w.Close() + + for { + if ctx.Err() != nil { + return ctx.Err() + } + n, err := w.Next() + if err != nil { + return err + } + if n.Health != nil { + if _, ok := n.Health.Warnings[health.LoginStateWarnable.Code]; ok { + logger.Info("Auth key failed to authenticate (may be expired or single-use), requesting new key from operator") + select { + case reissueCh <- struct{}{}: + case <-ctx.Done(): + } + return nil + } + } + } +} + +// handleAuthKeyReissue orchestrates the auth key reissue flow: +// 1. Disconnect from control +// 2. Set reissue marker in state Secret +// 3. Wait for operator to provide new key +// 4. Exit cleanly (Kubernetes will restart the pod with the new key) +func handleAuthKeyReissue(ctx context.Context, lc *local.Client, kc kubeclient.Client, stateSecretName string, currentAuthKey string, cfgChan <-chan *conf.Config, logger *zap.SugaredLogger) error { + if err := lc.DisconnectControl(ctx); err != nil { + return fmt.Errorf("error disconnecting from control: %w", err) + } + if err := authkey.SetReissueAuthKey(ctx, kc, stateSecretName, currentAuthKey, k8sProxyFieldManager); err != nil { + return fmt.Errorf("failed to set reissue_authkey in Kubernetes Secret: %w", err) + } + + var mu sync.Mutex + var latestAuthKey string + notify := make(chan struct{}, 1) + + // we use this go func to abstract away conf.Config from the shared function + go func() { + for cfg := range cfgChan { + if cfg.Parsed.AuthKey != nil { + mu.Lock() + latestAuthKey = *cfg.Parsed.AuthKey + mu.Unlock() + select { + case notify <- struct{}{}: + default: + } + } + } + }() + + getAuthKey := func() string { + mu.Lock() + defer mu.Unlock() + return latestAuthKey + } + clearFn := func(ctx context.Context) error { + return authkey.ClearReissueAuthKey(ctx, kc, stateSecretName, k8sProxyFieldManager) + } + + return authkey.WaitForAuthKeyReissue(ctx, currentAuthKey, 10*time.Minute, getAuthKey, clearFn, notify) +} diff --git a/cmd/k8s-proxy/kube_test.go b/cmd/k8s-proxy/kube_test.go new file mode 100644 index 0000000000000..c7e0f33d02b9e --- /dev/null +++ b/cmd/k8s-proxy/kube_test.go @@ -0,0 +1,141 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/health" + "tailscale.com/kube/kubeapi" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" +) + +func TestResetState(t *testing.T) { + tests := []struct { + name string + existingData map[string][]byte + podUID string + configAuthKey string + wantPatched map[string][]byte + }{ + { + name: "sets_capver_and_pod_uid", + existingData: map[string][]byte{ + kubetypes.KeyDeviceID: []byte("device-123"), + kubetypes.KeyDeviceFQDN: []byte("node.tailnet"), + kubetypes.KeyDeviceIPs: []byte(`["100.64.0.1"]`), + }, + podUID: "pod-123", + configAuthKey: "new-key", + wantPatched: map[string][]byte{ + kubetypes.KeyPodUID: []byte("pod-123"), + }, + }, + { + name: "clears_reissue_marker_when_actioned", + existingData: map[string][]byte{ + kubetypes.KeyReissueAuthkey: []byte("old-key"), + }, + podUID: "pod-123", + configAuthKey: "new-key", + wantPatched: map[string][]byte{ + kubetypes.KeyPodUID: []byte("pod-123"), + kubetypes.KeyReissueAuthkey: nil, + }, + }, + { + name: "keeps_reissue_marker_when_not_actioned", + existingData: map[string][]byte{ + kubetypes.KeyReissueAuthkey: []byte("old-key"), + }, + podUID: "pod-123", + configAuthKey: "old-key", + wantPatched: map[string][]byte{ + kubetypes.KeyPodUID: []byte("pod-123"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.wantPatched[kubetypes.KeyCapVer] = fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion) + + var patched map[string][]byte + kc := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + return &kubeapi.Secret{Data: tt.existingData}, nil + }, + StrategicMergePatchSecretImpl: func(ctx context.Context, name string, s *kubeapi.Secret, fm string) error { + patched = s.Data + return nil + }, + } + + err := resetState(context.Background(), kc, "test-secret", tt.podUID, tt.configAuthKey) + if err != nil { + t.Fatalf("resetState() error = %v", err) + } + + if diff := cmp.Diff(tt.wantPatched, patched); diff != "" { + t.Errorf("resetState() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestNeedsAuthKeyReissue(t *testing.T) { + loginWarnableCode := string(health.LoginStateWarnable.Code) + + tests := []struct { + name string + backendState string + health []string + want bool + }{ + { + name: "running_healthy", + backendState: "Running", + want: false, + }, + { + name: "needs_login", + backendState: "NeedsLogin", + want: true, + }, + { + name: "running_with_login_warning", + backendState: "Running", + health: []string{"warning: " + loginWarnableCode + ": you are logged out"}, + want: true, + }, + { + name: "running_with_unrelated_warning", + backendState: "Running", + health: []string{"dns-not-working"}, + want: false, + }, + { + name: "running_no_warnings", + backendState: "Running", + health: nil, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := needsAuthKeyReissue(tt.backendState, tt.health) + if got != tt.want { + t.Errorf("needsAuthKeyReissue() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/mkpkg/main.go b/cmd/mkpkg/main.go index 6f4de7e299b50..ecf108c2ec236 100644 --- a/cmd/mkpkg/main.go +++ b/cmd/mkpkg/main.go @@ -24,7 +24,7 @@ func parseFiles(s string, typ string) (files.Contents, error) { return nil, nil } var contents files.Contents - for _, f := range strings.Split(s, ",") { + for f := range strings.SplitSeq(s, ",") { fs := strings.Split(f, ":") if len(fs) != 2 { return nil, fmt.Errorf("unparseable file field %q", f) @@ -41,7 +41,7 @@ func parseEmptyDirs(s string) files.Contents { return nil } var contents files.Contents - for _, d := range strings.Split(s, ",") { + for d := range strings.SplitSeq(s, ",") { contents = append(contents, &files.Content{Type: files.TypeDir, Destination: d}) } return contents diff --git a/cmd/nardump/nardump.go b/cmd/nardump/nardump.go index c8db24cb6736d..38a2a67319595 100644 --- a/cmd/nardump/nardump.go +++ b/cmd/nardump/nardump.go @@ -9,22 +9,13 @@ // git-pull-oss.sh having Nix available. package main -// For the format, see: -// See https://gist.github.com/jbeda/5c79d2b1434f0018d693 - import ( - "bufio" - "crypto/sha256" - "encoding/base64" - "encoding/binary" "flag" "fmt" - "io" - "io/fs" "log" "os" - "path" - "sort" + + "tailscale.com/cmd/nardump/nardump" ) var sri = flag.Bool("sri", false, "print SRI") @@ -34,167 +25,16 @@ func main() { if flag.NArg() != 1 { log.Fatal("usage: nardump ") } - arg := flag.Arg(0) - if err := os.Chdir(arg); err != nil { - log.Fatal(err) - } + fsys := os.DirFS(flag.Arg(0)) if *sri { - hash := sha256.New() - if err := writeNAR(hash, os.DirFS(".")); err != nil { + s, err := nardump.SRI(fsys) + if err != nil { log.Fatal(err) } - fmt.Printf("sha256-%s\n", base64.StdEncoding.EncodeToString(hash.Sum(nil))) + fmt.Println(s) return } - bw := bufio.NewWriter(os.Stdout) - if err := writeNAR(bw, os.DirFS(".")); err != nil { + if err := nardump.WriteNAR(os.Stdout, fsys); err != nil { log.Fatal(err) } - bw.Flush() -} - -// writeNARError is a sentinel panic type that's recovered by writeNAR -// and converted into the wrapped error. -type writeNARError struct{ err error } - -// narWriter writes NAR files. -type narWriter struct { - w io.Writer - fs fs.FS -} - -// writeNAR writes a NAR file to w from the root of fs. -func writeNAR(w io.Writer, fs fs.FS) (err error) { - defer func() { - if e := recover(); e != nil { - if we, ok := e.(writeNARError); ok { - err = we.err - return - } - panic(e) - } - }() - nw := &narWriter{w: w, fs: fs} - nw.str("nix-archive-1") - return nw.writeDir(".") -} - -func (nw *narWriter) writeDir(dirPath string) error { - ents, err := fs.ReadDir(nw.fs, dirPath) - if err != nil { - return err - } - sort.Slice(ents, func(i, j int) bool { - return ents[i].Name() < ents[j].Name() - }) - nw.str("(") - nw.str("type") - nw.str("directory") - for _, ent := range ents { - nw.str("entry") - nw.str("(") - nw.str("name") - nw.str(ent.Name()) - nw.str("node") - mode := ent.Type() - sub := path.Join(dirPath, ent.Name()) - var err error - switch { - case mode.IsDir(): - err = nw.writeDir(sub) - case mode.IsRegular(): - err = nw.writeRegular(sub) - case mode&os.ModeSymlink != 0: - err = nw.writeSymlink(sub) - default: - return fmt.Errorf("unsupported file type %v at %q", sub, mode) - } - if err != nil { - return err - } - nw.str(")") - } - nw.str(")") - return nil -} - -func (nw *narWriter) writeRegular(path string) error { - nw.str("(") - nw.str("type") - nw.str("regular") - fi, err := fs.Stat(nw.fs, path) - if err != nil { - return err - } - if fi.Mode()&0111 != 0 { - nw.str("executable") - nw.str("") - } - contents, err := fs.ReadFile(nw.fs, path) - if err != nil { - return err - } - nw.str("contents") - if err := writeBytes(nw.w, contents); err != nil { - return err - } - nw.str(")") - return nil -} - -func (nw *narWriter) writeSymlink(path string) error { - nw.str("(") - nw.str("type") - nw.str("symlink") - nw.str("target") - // broken symlinks are valid in a nar - // given we do os.chdir(dir) and os.dirfs(".") above - // readlink now resolves relative links even if they are broken - link, err := os.Readlink(path) - if err != nil { - return err - } - nw.str(link) - nw.str(")") - return nil -} - -func (nw *narWriter) str(s string) { - if err := writeString(nw.w, s); err != nil { - panic(writeNARError{err}) - } -} - -func writeString(w io.Writer, s string) error { - var buf [8]byte - binary.LittleEndian.PutUint64(buf[:], uint64(len(s))) - if _, err := w.Write(buf[:]); err != nil { - return err - } - if _, err := io.WriteString(w, s); err != nil { - return err - } - return writePad(w, len(s)) -} - -func writeBytes(w io.Writer, b []byte) error { - var buf [8]byte - binary.LittleEndian.PutUint64(buf[:], uint64(len(b))) - if _, err := w.Write(buf[:]); err != nil { - return err - } - if _, err := w.Write(b); err != nil { - return err - } - return writePad(w, len(b)) -} - -func writePad(w io.Writer, n int) error { - pad := n % 8 - if pad == 0 { - return nil - } - var zeroes [8]byte - _, err := w.Write(zeroes[:8-pad]) - return err } diff --git a/cmd/nardump/nardump/nardump.go b/cmd/nardump/nardump/nardump.go new file mode 100644 index 0000000000000..ab9ff1f3cdcd8 --- /dev/null +++ b/cmd/nardump/nardump/nardump.go @@ -0,0 +1,193 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package nardump writes a NAR (Nix Archive) representation of an +// fs.FS to an io.Writer, or summarizes it as a Subresource Integrity +// hash, as used by Nix flake.nix vendor and toolchain hashes. +// +// For the format, see: +// https://gist.github.com/jbeda/5c79d2b1434f0018d693 +package nardump + +import ( + "bufio" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "io/fs" + "path" + "sort" +) + +// WriteNAR writes a NAR-encoded representation of fsys, rooted at +// the FS root, to w. +// +// The encoder issues many small writes; if w is not already a +// *bufio.Writer, WriteNAR wraps it in one and flushes on return so +// the caller doesn't have to. +// +// fsys must implement fs.ReadLinkFS to encode any symlinks it +// contains; os.DirFS satisfies this on Go 1.25+. +func WriteNAR(w io.Writer, fsys fs.FS) (err error) { + defer func() { + if e := recover(); e != nil { + if we, ok := e.(writeNARError); ok { + err = we.err + return + } + panic(e) + } + }() + bw, ok := w.(*bufio.Writer) + if !ok { + bw = bufio.NewWriter(w) + defer func() { + if flushErr := bw.Flush(); err == nil { + err = flushErr + } + }() + } + nw := &narWriter{w: bw, fs: fsys} + nw.str("nix-archive-1") + return nw.writeDir(".") +} + +// SRI returns the Subresource Integrity hash of the NAR encoding of +// fsys, in the form "sha256-". This is the format Nix +// expects for vendorHash and similar fields. +func SRI(fsys fs.FS) (string, error) { + h := sha256.New() + if err := WriteNAR(h, fsys); err != nil { + return "", err + } + return "sha256-" + base64.StdEncoding.EncodeToString(h.Sum(nil)), nil +} + +// writeNARError is a sentinel panic type that's recovered by +// WriteNAR and converted into the wrapped error. +type writeNARError struct{ err error } + +// narWriter writes NAR files. +type narWriter struct { + w io.Writer + fs fs.FS +} + +func (nw *narWriter) writeDir(dirPath string) error { + ents, err := fs.ReadDir(nw.fs, dirPath) + if err != nil { + return err + } + sort.Slice(ents, func(i, j int) bool { + return ents[i].Name() < ents[j].Name() + }) + nw.str("(") + nw.str("type") + nw.str("directory") + for _, ent := range ents { + nw.str("entry") + nw.str("(") + nw.str("name") + nw.str(ent.Name()) + nw.str("node") + mode := ent.Type() + sub := path.Join(dirPath, ent.Name()) + var err error + switch { + case mode.IsDir(): + err = nw.writeDir(sub) + case mode.IsRegular(): + err = nw.writeRegular(sub) + case mode&fs.ModeSymlink != 0: + err = nw.writeSymlink(sub) + default: + return fmt.Errorf("unsupported file type %v at %q", sub, mode) + } + if err != nil { + return err + } + nw.str(")") + } + nw.str(")") + return nil +} + +func (nw *narWriter) writeRegular(p string) error { + nw.str("(") + nw.str("type") + nw.str("regular") + fi, err := fs.Stat(nw.fs, p) + if err != nil { + return err + } + if fi.Mode()&0111 != 0 { + nw.str("executable") + nw.str("") + } + contents, err := fs.ReadFile(nw.fs, p) + if err != nil { + return err + } + nw.str("contents") + if err := writeBytes(nw.w, contents); err != nil { + return err + } + nw.str(")") + return nil +} + +func (nw *narWriter) writeSymlink(p string) error { + nw.str("(") + nw.str("type") + nw.str("symlink") + nw.str("target") + link, err := fs.ReadLink(nw.fs, p) + if err != nil { + return err + } + nw.str(link) + nw.str(")") + return nil +} + +func (nw *narWriter) str(s string) { + if err := writeString(nw.w, s); err != nil { + panic(writeNARError{err}) + } +} + +func writeString(w io.Writer, s string) error { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(len(s))) + if _, err := w.Write(buf[:]); err != nil { + return err + } + if _, err := io.WriteString(w, s); err != nil { + return err + } + return writePad(w, len(s)) +} + +func writeBytes(w io.Writer, b []byte) error { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(len(b))) + if _, err := w.Write(buf[:]); err != nil { + return err + } + if _, err := w.Write(b); err != nil { + return err + } + return writePad(w, len(b)) +} + +func writePad(w io.Writer, n int) error { + pad := n % 8 + if pad == 0 { + return nil + } + var zeroes [8]byte + _, err := w.Write(zeroes[:8-pad]) + return err +} diff --git a/cmd/nardump/nardump/nardump_test.go b/cmd/nardump/nardump/nardump_test.go new file mode 100644 index 0000000000000..16b690ee257f0 --- /dev/null +++ b/cmd/nardump/nardump/nardump_test.go @@ -0,0 +1,55 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package nardump + +import ( + "crypto/sha256" + "fmt" + "os" + "path/filepath" + "runtime" + "testing" +) + +// setupTmpdir sets up a known golden layout, covering all allowed file/folder types in a nar. +func setupTmpdir(t *testing.T) string { + t.Helper() + tmpdir := t.TempDir() + must := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + must(os.MkdirAll(filepath.Join(tmpdir, "sub/dir"), 0755)) + must(os.Symlink("brokenfile", filepath.Join(tmpdir, "brokenlink"))) + must(os.Symlink("sub/dir", filepath.Join(tmpdir, "dirl"))) + must(os.Symlink("/abs/nonexistentdir", filepath.Join(tmpdir, "dirb"))) + f, err := os.Create(filepath.Join(tmpdir, "sub/dir/file1")) + must(err) + f.Close() + f, err = os.Create(filepath.Join(tmpdir, "file2m")) + must(err) + must(f.Truncate(2 * 1024 * 1024)) + f.Close() + must(os.Symlink("../file2m", filepath.Join(tmpdir, "sub/goodlink"))) + return tmpdir +} + +func TestWriteNAR(t *testing.T) { + if runtime.GOOS == "windows" { + // Skip test on Windows as the Nix package manager is not supported on this platform + t.Skip("nix package manager is not available on Windows") + } + dir := setupTmpdir(t) + // obtained via `nix-store --dump /tmp/... | sha256sum` of the above test dir + const expected = "727613a36f41030e93a4abf2649c3ec64a2757ccff364e3f6f7d544eb976e442" + h := sha256.New() + if err := WriteNAR(h, os.DirFS(dir)); err != nil { + t.Fatal(err) + } + if got := fmt.Sprintf("%x", h.Sum(nil)); got != expected { + t.Fatalf("sha256sum of nar: got %s, want %s", got, expected) + } +} diff --git a/cmd/nardump/nardump_test.go b/cmd/nardump/nardump_test.go deleted file mode 100644 index c1ca825e1e288..0000000000000 --- a/cmd/nardump/nardump_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "crypto/sha256" - "fmt" - "os" - "runtime" - "testing" -) - -// setupTmpdir sets up a known golden layout, covering all allowed file/folder types in a nar -func setupTmpdir(t *testing.T) string { - tmpdir := t.TempDir() - pwd, _ := os.Getwd() - os.Chdir(tmpdir) - defer os.Chdir(pwd) - os.MkdirAll("sub/dir", 0755) - os.Symlink("brokenfile", "brokenlink") - os.Symlink("sub/dir", "dirl") - os.Symlink("/abs/nonexistentdir", "dirb") - os.Create("sub/dir/file1") - f, _ := os.Create("file2m") - _ = f.Truncate(2 * 1024 * 1024) - f.Close() - os.Symlink("../file2m", "sub/goodlink") - return tmpdir -} - -func TestWriteNar(t *testing.T) { - if runtime.GOOS == "windows" { - // Skip test on Windows as the Nix package manager is not supported on this platform - t.Skip("nix package manager is not available on Windows") - } - dir := setupTmpdir(t) - t.Run("nar", func(t *testing.T) { - // obtained via `nix-store --dump /tmp/... | sha256sum` of the above test dir - expected := "727613a36f41030e93a4abf2649c3ec64a2757ccff364e3f6f7d544eb976e442" - h := sha256.New() - os.Chdir(dir) - err := writeNAR(h, os.DirFS(".")) - if err != nil { - t.Fatal(err) - } - hash := fmt.Sprintf("%x", h.Sum(nil)) - if expected != hash { - t.Fatal("sha256sum of nar not matched", hash, expected) - } - }) -} diff --git a/cmd/natc/ippool/ippool_test.go b/cmd/natc/ippool/ippool_test.go index 405ec61564ed8..af0053c2f54d8 100644 --- a/cmd/natc/ippool/ippool_test.go +++ b/cmd/natc/ippool/ippool_test.go @@ -30,7 +30,7 @@ func TestIPPoolExhaustion(t *testing.T) { from := tailcfg.NodeID(12345) - for i := 0; i < 5; i++ { + for range 5 { for _, domain := range domains { addr, err := pool.IPForDomain(from, domain) if err != nil { diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go index 11975b7d2e1a6..877f16cc02689 100644 --- a/cmd/natc/natc.go +++ b/cmd/natc/natc.go @@ -82,14 +82,14 @@ func main() { log.Fatalf("site-id must be in the range [0, 65535]") } - var ignoreDstTable *bart.Table[bool] + var ignoreDstTable *bart.Lite for s := range strings.SplitSeq(*ignoreDstPfxStr, ",") { s := strings.TrimSpace(s) if s == "" { continue } if ignoreDstTable == nil { - ignoreDstTable = &bart.Table[bool]{} + ignoreDstTable = &bart.Lite{} } pfx, err := netip.ParsePrefix(s) if err != nil { @@ -98,7 +98,7 @@ func main() { if pfx.Masked() != pfx { log.Fatalf("prefix %v is not normalized (bits are set outside the mask)", pfx) } - ignoreDstTable.Insert(pfx, true) + ignoreDstTable.Insert(pfx) } ts := &tsnet.Server{ Hostname: *hostname, @@ -149,7 +149,7 @@ func main() { } var prefixes []netip.Prefix - for _, s := range strings.Split(*v4PfxStr, ",") { + for s := range strings.SplitSeq(*v4PfxStr, ",") { p := netip.MustParsePrefix(strings.TrimSpace(s)) if p.Masked() != p { log.Fatalf("v4 prefix %v is not a masked prefix", p) @@ -276,7 +276,7 @@ type connector struct { // and if any of the ip addresses in response to the lookup match any 'ignore destinations' prefix we will // return a dns response that contains the ip addresses we discovered with the lookup (ie not the // natc behavior, which would return a dummy ip address pointing at natc). - ignoreDsts *bart.Table[bool] + ignoreDsts *bart.Lite // ipPool contains the per-peer IPv4 address assignments. ipPool ippool.IPPool @@ -372,8 +372,7 @@ func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDP addrQCount++ if _, ok := resolves[q.Name.String()]; !ok { addrs, err := c.resolver.LookupNetIP(ctx, "ip", q.Name.String()) - var dnsErr *net.DNSError - if errors.As(err, &dnsErr) && dnsErr.IsNotFound { + if dnsErr, ok := errors.AsType[*net.DNSError](err); ok && dnsErr.IsNotFound { continue } if err != nil { @@ -539,7 +538,7 @@ func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool { return false } for _, a := range dstAddrs { - if _, ok := c.ignoreDsts.Lookup(a); ok { + if c.ignoreDsts.Contains(a) { return true } } diff --git a/cmd/natc/natc_test.go b/cmd/natc/natc_test.go index e1cc061234d0e..00c94868ec8a2 100644 --- a/cmd/natc/natc_test.go +++ b/cmd/natc/natc_test.go @@ -268,13 +268,13 @@ func TestDNSResponse(t *testing.T) { }, }, }, - ignoreDsts: &bart.Table[bool]{}, + ignoreDsts: &bart.Lite{}, routes: routes, v6ULA: v6ULA, ipPool: &ippool.SingleMachineIPPool{IPSet: addrPool}, dnsAddr: dnsAddr, } - c.ignoreDsts.Insert(netip.MustParsePrefix("8.8.4.4/32"), true) + c.ignoreDsts.Insert(netip.MustParsePrefix("8.8.4.4/32")) for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { @@ -411,9 +411,9 @@ func TestDNSResponse(t *testing.T) { } func TestIgnoreDestination(t *testing.T) { - ignoreDstTable := &bart.Table[bool]{} - ignoreDstTable.Insert(netip.MustParsePrefix("192.168.1.0/24"), true) - ignoreDstTable.Insert(netip.MustParsePrefix("10.0.0.0/8"), true) + ignoreDstTable := &bart.Lite{} + ignoreDstTable.Insert(netip.MustParsePrefix("192.168.1.0/24")) + ignoreDstTable.Insert(netip.MustParsePrefix("10.0.0.0/8")) c := &connector{ ignoreDsts: ignoreDstTable, diff --git a/cmd/sniproxy/sniproxy.go b/cmd/sniproxy/sniproxy.go index 45503feca8718..f7ebc6abaa4e5 100644 --- a/cmd/sniproxy/sniproxy.go +++ b/cmd/sniproxy/sniproxy.go @@ -138,9 +138,9 @@ func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, pro } // Finally, start mainloop to configure app connector based on information - // in the netmap. - // We set the NotifyInitialNetMap flag so we will always get woken with the - // current netmap, before only being woken on changes. + // in the self node's CapMap. We set NotifyInitialNetMap so the first + // Notify carries the current self node (now via Notify.SelfChange); + // subsequent self changes wake us up too. bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap) if err != nil { log.Fatalf("watching IPN bus: %v", err) @@ -155,28 +155,30 @@ func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, pro log.Fatalf("reading IPN bus: %v", err) } - // NetMap contains app-connector configuration - if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() { - var c appctype.AppConnectorConfig - nmConf, err := tailcfg.UnmarshalNodeCapViewJSON[appctype.AppConnectorConfig](nm.SelfNode.CapMap(), configCapKey) - if err != nil { - log.Printf("failed to read app connector configuration from coordination server: %v", err) - } else if len(nmConf) > 0 { - c = nmConf[0] - } + self := msg.SelfChange + if self == nil { + continue + } + var c appctype.AppConnectorConfig + // View() lets us reuse the existing CapView decoder. + nmConf, err := tailcfg.UnmarshalNodeCapViewJSON[appctype.AppConnectorConfig](self.View().CapMap(), configCapKey) + if err != nil { + log.Printf("failed to read app connector configuration from coordination server: %v", err) + } else if len(nmConf) > 0 { + c = nmConf[0] + } - if c.AdvertiseRoutes { - if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil { - log.Printf("failed to advertise routes: %v", err) - } + if c.AdvertiseRoutes { + if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil { + log.Printf("failed to advertise routes: %v", err) } - - // Backwards compatibility: combine any configuration from control with flags specified - // on the command line. This is intentionally done after we advertise any routes - // because its never correct to advertise the nodes native IP addresses. - s.mergeConfigFromFlags(&c, ports, forwards) - s.srv.Configure(&c) } + + // Backwards compatibility: combine any configuration from control with flags specified + // on the command line. This is intentionally done after we advertise any routes + // because its never correct to advertise the nodes native IP addresses. + s.mergeConfigFromFlags(&c, ports, forwards) + s.srv.Configure(&c) } } @@ -225,7 +227,7 @@ func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, Addrs: []netip.Addr{ip4, ip6}, } if ports != "" { - for _, portStr := range strings.Split(ports, ",") { + for portStr := range strings.SplitSeq(ports, ",") { port, err := strconv.ParseUint(portStr, 10, 16) if err != nil { log.Fatalf("invalid port: %s", portStr) @@ -238,7 +240,7 @@ func (s *sniproxy) mergeConfigFromFlags(out *appctype.AppConnectorConfig, ports, } var forwardConfigFromFlags []appctype.DNATConfig - for _, forwStr := range strings.Split(forwards, ",") { + for forwStr := range strings.SplitSeq(forwards, ",") { if forwStr == "" { continue } diff --git a/cmd/speedtest/speedtest.go b/cmd/speedtest/speedtest.go index 2cea97b1edef1..e11c4ad1d90bb 100644 --- a/cmd/speedtest/speedtest.go +++ b/cmd/speedtest/speedtest.go @@ -72,8 +72,7 @@ var speedtestArgs struct { func runSpeedtest(ctx context.Context, args []string) error { if _, _, err := net.SplitHostPort(speedtestArgs.host); err != nil { - var addrErr *net.AddrError - if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" { + if addrErr, ok := errors.AsType[*net.AddrError](err); ok && addrErr.Err == "missing port in address" { // if no port is provided, append the default port speedtestArgs.host = net.JoinHostPort(speedtestArgs.host, strconv.Itoa(speedtest.DefaultPort)) } diff --git a/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go b/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go index 3c3ade3cd35a3..a2cd3acd2e399 100644 --- a/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go +++ b/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go @@ -28,8 +28,8 @@ import ( "path/filepath" "time" - gossh "golang.org/x/crypto/ssh" - "tailscale.com/tempfork/gliderlabs/ssh" + gliderssh "github.com/tailscale/gliderssh" + "golang.org/x/crypto/ssh" ) // keyTypes are the SSH key types that we either try to read from the @@ -60,23 +60,23 @@ func main() { log.Fatal("no host keys") } - srv := &ssh.Server{ + srv := &gliderssh.Server{ Addr: *addr, Version: "Tailscale", Handler: handleSessionPostSSHAuth, - ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { + ServerConfigCallback: func(ctx gliderssh.Context) *ssh.ServerConfig { start := time.Now() - var spac gossh.ServerPreAuthConn - return &gossh.ServerConfig{ - PreAuthConnCallback: func(conn gossh.ServerPreAuthConn) { + var spac ssh.ServerPreAuthConn + return &ssh.ServerConfig{ + PreAuthConnCallback: func(conn ssh.ServerPreAuthConn) { spac = conn }, NoClientAuth: true, // required for the NoClientAuthCallback to run - NoClientAuthCallback: func(cm gossh.ConnMetadata) (*gossh.Permissions, error) { + NoClientAuthCallback: func(cm ssh.ConnMetadata) (*ssh.Permissions, error) { spac.SendAuthBanner(fmt.Sprintf("# Banner: doing none auth at %v\r\n", time.Since(start))) if cm.User() == "denyme" { - return nil, &gossh.BannerError{ + return nil, &ssh.BannerError{ Err: errors.New("denying access"), Message: "denyme is not allowed to access this machine\n", } @@ -96,7 +96,7 @@ func main() { } return nil, nil }, - BannerCallback: func(cm gossh.ConnMetadata) string { + BannerCallback: func(cm ssh.ConnMetadata) string { log.Printf("Got connection from user %q, %q from %v", cm.User(), cm.ClientVersion(), cm.RemoteAddr()) return fmt.Sprintf("# Banner for user %q, %q\n", cm.User(), cm.ClientVersion()) }, @@ -115,7 +115,7 @@ func main() { log.Printf("done") } -func handleSessionPostSSHAuth(s ssh.Session) { +func handleSessionPostSSHAuth(s gliderssh.Session) { log.Printf("Started session from user %q", s.User()) fmt.Fprintf(s, "Hello user %q, it worked.\n", s.User()) @@ -143,13 +143,13 @@ func handleSessionPostSSHAuth(s ssh.Session) { s.Exit(0) } -func getHostKeys(dir string) (ret []ssh.Signer, err error) { +func getHostKeys(dir string) (ret []gliderssh.Signer, err error) { for _, typ := range keyTypes { hostKey, err := hostKeyFileOrCreate(dir, typ) if err != nil { return nil, err } - signer, err := gossh.ParsePrivateKey(hostKey) + signer, err := ssh.ParsePrivateKey(hostKey) if err != nil { return nil, err } diff --git a/cmd/stund/depaware.txt b/cmd/stund/depaware.txt index d25974b2df424..7804915dc7e05 100644 --- a/cmd/stund/depaware.txt +++ b/cmd/stund/depaware.txt @@ -71,11 +71,11 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar tailscale.com/types/logger from tailscale.com/tsweb+ tailscale.com/types/opt from tailscale.com/envknob+ tailscale.com/types/persist from tailscale.com/feature - tailscale.com/types/ptr from tailscale.com/tailcfg+ tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/tailcfg+ tailscale.com/types/tkatype from tailscale.com/tailcfg+ tailscale.com/types/views from tailscale.com/net/tsaddr+ + tailscale.com/util/bufiox from tailscale.com/types/key tailscale.com/util/ctxkey from tailscale.com/tsweb+ L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/tailcfg diff --git a/cmd/stunstamp/stunstamp.go b/cmd/stunstamp/stunstamp.go index cfedd82bdd5cc..743d6aec3c9d8 100644 --- a/cmd/stunstamp/stunstamp.go +++ b/cmd/stunstamp/stunstamp.go @@ -889,8 +889,7 @@ func remoteWriteTimeSeries(client *remoteWriteClient, tsCh chan []prompb.TimeSer reqCtx, cancel := context.WithTimeout(context.Background(), time.Second*30) writeErr = client.write(reqCtx, ts) cancel() - var re recoverableErr - recoverable := errors.As(writeErr, &re) + _, recoverable := errors.AsType[recoverableErr](writeErr) if writeErr != nil { log.Printf("remote write error(recoverable=%v): %v", recoverable, writeErr) } diff --git a/cmd/systray/systray.go b/cmd/systray/systray.go index 9dc35f1420bee..68a3397820274 100644 --- a/cmd/systray/systray.go +++ b/cmd/systray/systray.go @@ -15,9 +15,11 @@ import ( ) var socket = flag.String("socket", paths.DefaultTailscaledSocket(), "path to tailscaled socket") +var theme = flag.String("theme", "dark", "color theme for Tailscale icon: dark, dark:nobg, light, light:nobg") func main() { flag.Parse() lc := &local.Client{Socket: *socket} + systray.SetTheme(*theme) new(systray.Menu).Run(lc) } diff --git a/cmd/tailscale/cli/appcroutes.go b/cmd/tailscale/cli/appcroutes.go index 2ea001aec9c84..04cbcdd832258 100644 --- a/cmd/tailscale/cli/appcroutes.go +++ b/cmd/tailscale/cli/appcroutes.go @@ -102,12 +102,12 @@ func getSummarizeLearnedOutput(ri *appctype.RouteInfo) string { } return 0 }) - s := "" + var s strings.Builder fmtString := fmt.Sprintf("%%-%ds %%d\n", maxDomainWidth) // eg "%-10s %d\n" for _, dc := range x { - s += fmt.Sprintf(fmtString, dc.domain, dc.count) + s.WriteString(fmt.Sprintf(fmtString, dc.domain, dc.count)) } - return s + return s.String() } func runAppcRoutesInfo(ctx context.Context, args []string) error { diff --git a/cmd/tailscale/cli/cli.go b/cmd/tailscale/cli/cli.go index fda6b4546324a..a9d7364c275d2 100644 --- a/cmd/tailscale/cli/cli.go +++ b/cmd/tailscale/cli/cli.go @@ -28,6 +28,7 @@ import ( "tailscale.com/feature" "tailscale.com/paths" "tailscale.com/util/slicesx" + "tailscale.com/util/testenv" "tailscale.com/version/distro" ) @@ -124,7 +125,7 @@ func Run(args []string) (err error) { if errors.Is(err, flag.ErrHelp) { return nil } - if noexec := (ffcli.NoExecError{}); errors.As(err, &noexec) { + if noexec, ok := errors.AsType[ffcli.NoExecError](err); ok { // When the user enters an unknown subcommand, ffcli tries to run // the closest valid parent subcommand with everything else as args, // returning NoExecError if it doesn't have an Exec function. @@ -194,17 +195,39 @@ func (v *onceFlagValue) IsBoolFlag() bool { return ok && bf.IsBoolFlag() } -// noDupFlagify modifies c recursively to make all the -// flag values be wrappers that permit setting the value -// at most once. -func noDupFlagify(c *ffcli.Command) { - if c.FlagSet != nil { - c.FlagSet.VisitAll(func(f *flag.Flag) { - f.Value = &onceFlagValue{Value: f.Value} - }) +// noDupFlagify modifies c recursively to make all the flag values be +// wrappers that permit setting the value at most once. If tb is +// non-nil, the original values are restored when the test completes. +func noDupFlagify(c *ffcli.Command, tb testenv.TB) { + if tb == nil && testenv.InTest() { + return } - for _, sub := range c.Subcommands { - noDupFlagify(sub) + type restore struct { + f *flag.Flag + v flag.Value + } + var restores []restore + var walk func(*ffcli.Command) + walk = func(c *ffcli.Command) { + if c.FlagSet != nil { + c.FlagSet.VisitAll(func(f *flag.Flag) { + if tb != nil { + restores = append(restores, restore{f, f.Value}) + } + f.Value = &onceFlagValue{Value: f.Value} + }) + } + for _, sub := range c.Subcommands { + walk(sub) + } + } + walk(c) + if tb != nil { + tb.Cleanup(func() { + for _, r := range restores { + r.f.Value = r.v + } + }) } } @@ -221,7 +244,7 @@ var ( _ func() *ffcli.Command ) -func newRootCmd() *ffcli.Command { +func newRootCmd(tb ...testenv.TB) *ffcli.Command { rootfs := newFlagSet("tailscale") rootfs.Func("socket", "path to tailscaled socket", func(s string) error { localClient.Socket = s @@ -303,7 +326,11 @@ change in the future. }) ffcomplete.Inject(rootCmd, func(c *ffcli.Command) { c.LongHelp = hidden + c.LongHelp }, usageFunc) - noDupFlagify(rootCmd) + var t testenv.TB + if len(tb) > 0 { + t = tb[0] + } + noDupFlagify(rootCmd, t) return rootCmd } diff --git a/cmd/tailscale/cli/cli_test.go b/cmd/tailscale/cli/cli_test.go index 537e641fc4160..d2df825d3786b 100644 --- a/cmd/tailscale/cli/cli_test.go +++ b/cmd/tailscale/cli/cli_test.go @@ -769,7 +769,7 @@ func TestPrefsFromUpArgs(t *testing.T) { args: upArgsT{ exitNodeIP: "foo", }, - wantErr: `invalid value "foo" for --exit-node; must be IP or unique node name`, + wantErr: `invalid value "foo" for --exit-node; must be IP or hostname`, }, { name: "error_exit_node_allow_lan_without_exit_node", @@ -779,11 +779,43 @@ func TestPrefsFromUpArgs(t *testing.T) { wantErr: `--exit-node-allow-lan-access can only be used with --exit-node`, }, { - name: "error_tag_prefix", + name: "error_tag_bad_prefix", args: upArgsT{ - advertiseTags: "foo", + advertiseTags: "notatag:foo", + }, + wantErr: `tag: "notatag:foo": tags must start with 'tag:'`, + }, + { + name: "tag_auto_prefix", + args: upArgsFromOSArgs("linux", "--advertise-tags=foo,bar"), + want: &ipn.Prefs{ + ControlURL: ipn.DefaultControlURL, + WantRunning: true, + CorpDNS: true, + AdvertiseTags: []string{"tag:foo", "tag:bar"}, + NoSNAT: false, + NoStatefulFiltering: "true", + NetfilterMode: preftype.NetfilterOn, + AutoUpdate: ipn.AutoUpdatePrefs{ + Check: true, + }, + }, + }, + { + name: "tag_mixed_prefix", + args: upArgsFromOSArgs("linux", "--advertise-tags=tag:foo,bar"), + want: &ipn.Prefs{ + ControlURL: ipn.DefaultControlURL, + WantRunning: true, + CorpDNS: true, + AdvertiseTags: []string{"tag:foo", "tag:bar"}, + NoSNAT: false, + NoStatefulFiltering: "true", + NetfilterMode: preftype.NetfilterOn, + AutoUpdate: ipn.AutoUpdatePrefs{ + Check: true, + }, }, - wantErr: `tag: "foo": tags must start with 'tag:'`, }, { name: "error_long_hostname", @@ -962,8 +994,8 @@ func TestPrefFlagMapping(t *testing.T) { } prefType := reflect.TypeFor[ipn.Prefs]() - for i := range prefType.NumField() { - prefName := prefType.Field(i).Name + for field := range prefType.Fields() { + prefName := field.Name if prefHasFlag[prefName] { continue } @@ -1533,13 +1565,13 @@ func TestParseNLArgs(t *testing.T) { parseDisablements: true, }, { - name: "key no votes", + name: "key-no-votes", input: []string{"nlpub:" + strings.Repeat("00", 32)}, parseKeys: true, wantKeys: []tka.Key{{Kind: tka.Key25519, Votes: 1, Public: bytes.Repeat([]byte{0}, 32)}}, }, { - name: "key with votes", + name: "key-with-votes", input: []string{"nlpub:" + strings.Repeat("01", 32) + "?5"}, parseKeys: true, wantKeys: []tka.Key{{Kind: tka.Key25519, Votes: 5, Public: bytes.Repeat([]byte{1}, 32)}}, @@ -1551,13 +1583,13 @@ func TestParseNLArgs(t *testing.T) { wantDisablements: [][]byte{bytes.Repeat([]byte{2}, 32), bytes.Repeat([]byte{3}, 32)}, }, { - name: "disablements not allowed", + name: "disablements-not-allowed", input: []string{"disablement:" + strings.Repeat("02", 32)}, parseKeys: true, wantErr: fmt.Errorf("parsing key 1: key hex string doesn't have expected type prefix tlpub:"), }, { - name: "keys not allowed", + name: "keys-not-allowed", input: []string{"nlpub:" + strings.Repeat("02", 32)}, parseDisablements: true, wantErr: fmt.Errorf("parsing argument 1: expected value with \"disablement:\" or \"disablement-secret:\" prefix, got %q", "nlpub:0202020202020202020202020202020202020202020202020202020202020202"), @@ -1618,7 +1650,7 @@ func TestNoDups(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cmd := newRootCmd() + cmd := newRootCmd(t) makeQuietContinueOnError(cmd) err := cmd.Parse(tt.args) if got := fmt.Sprint(err); got != tt.want { diff --git a/cmd/tailscale/cli/configure-kube.go b/cmd/tailscale/cli/configure-kube.go index 3dcec250f01ef..8160025c6858e 100644 --- a/cmd/tailscale/cli/configure-kube.go +++ b/cmd/tailscale/cli/configure-kube.go @@ -20,10 +20,8 @@ import ( "github.com/peterbourgon/ff/v3/ffcli" "k8s.io/client-go/util/homedir" "sigs.k8s.io/yaml" - "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" - "tailscale.com/types/netmap" "tailscale.com/util/dnsname" "tailscale.com/version" ) @@ -98,12 +96,12 @@ func runConfigureKubeconfig(ctx context.Context, args []string) error { if st.BackendState != "Running" { return errors.New("Tailscale is not running") } - nm, err := getNetMap(ctx) + dnsCfg, err := getDNSConfig(ctx) if err != nil { return err } - targetFQDN, err := nodeOrServiceDNSNameFromArg(st, nm, hostOrFQDNOrIP) + targetFQDN, err := nodeOrServiceDNSNameFromArg(st, dnsCfg, hostOrFQDNOrIP) if err != nil { return err } @@ -240,14 +238,14 @@ func setKubeconfigForPeer(scheme, fqdn, filePath string) error { // nodeOrServiceDNSNameFromArg returns the PeerStatus.DNSName value from a peer // in st that matches the input arg which can be a base name, full DNS name, or // an IP. If none is found, it looks for a Tailscale Service -func nodeOrServiceDNSNameFromArg(st *ipnstate.Status, nm *netmap.NetworkMap, arg string) (string, error) { +func nodeOrServiceDNSNameFromArg(st *ipnstate.Status, dns *tailcfg.DNSConfig, arg string) (string, error) { // First check for a node DNS name. if dnsName, ok := nodeDNSNameFromArg(st, arg); ok { return dnsName, nil } // If not found, check for a Tailscale Service DNS name. - rec, ok := serviceDNSRecordFromNetMap(nm, arg) + rec, ok := serviceDNSRecordFromDNSConfig(dns, arg) if !ok { return "", fmt.Errorf("no peer found for %q", arg) } @@ -269,25 +267,13 @@ func nodeOrServiceDNSNameFromArg(st *ipnstate.Status, nm *netmap.NetworkMap, arg return "", fmt.Errorf("%q is in MagicDNS, but is not currently reachable on any known peer", arg) } -func getNetMap(ctx context.Context) (*netmap.NetworkMap, error) { +func getDNSConfig(ctx context.Context) (*tailcfg.DNSConfig, error) { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - - watcher, err := localClient.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) - if err != nil { - return nil, err - } - defer watcher.Close() - - n, err := watcher.Next() - if err != nil { - return nil, err - } - - return n.NetMap, nil + return localClient.DNSConfig(ctx) } -func serviceDNSRecordFromNetMap(nm *netmap.NetworkMap, arg string) (rec tailcfg.DNSRecord, ok bool) { +func serviceDNSRecordFromDNSConfig(dns *tailcfg.DNSConfig, arg string) (rec tailcfg.DNSRecord, ok bool) { argIP, _ := netip.ParseAddr(arg) argFQDN, err := dnsname.ToFQDN(arg) argFQDNValid := err == nil @@ -295,7 +281,7 @@ func serviceDNSRecordFromNetMap(nm *netmap.NetworkMap, arg string) (rec tailcfg. return rec, false } - for _, rec := range nm.DNS.ExtraRecords { + for _, rec := range dns.ExtraRecords { if argIP.IsValid() { recIP, _ := netip.ParseAddr(rec.Value) if recIP == argIP { diff --git a/cmd/tailscale/cli/configure-kube_test.go b/cmd/tailscale/cli/configure-kube_test.go index 2df54d5751497..2c2a05ac0c08f 100644 --- a/cmd/tailscale/cli/configure-kube_test.go +++ b/cmd/tailscale/cli/configure-kube_test.go @@ -76,7 +76,7 @@ users: token: unused`, }, { - name: "all configs, clusters, users have been deleted", + name: "all-configs-clusters-users-deleted", in: `apiVersion: v1 clusters: null contexts: null diff --git a/cmd/tailscale/cli/configure-synology-cert.go b/cmd/tailscale/cli/configure-synology-cert.go index 0f38f2df2941c..32f5bbd70593c 100644 --- a/cmd/tailscale/cli/configure-synology-cert.go +++ b/cmd/tailscale/cli/configure-synology-cert.go @@ -16,6 +16,7 @@ import ( "os/exec" "path" "runtime" + "slices" "strings" "github.com/peterbourgon/ff/v3/ffcli" @@ -85,11 +86,8 @@ func runConfigureSynologyCert(ctx context.Context, args []string) error { domain = st.CertDomains[0] } else { var found bool - for _, d := range st.CertDomains { - if d == domain { - found = true - break - } + if slices.Contains(st.CertDomains, domain) { + found = true } if !found { return fmt.Errorf("Domain %q was not one of the valid domain options: %q.", domain, st.CertDomains) diff --git a/cmd/tailscale/cli/configure-synology-cert_test.go b/cmd/tailscale/cli/configure-synology-cert_test.go index 08369c135f154..d79ceb9d362b8 100644 --- a/cmd/tailscale/cli/configure-synology-cert_test.go +++ b/cmd/tailscale/cli/configure-synology-cert_test.go @@ -30,7 +30,7 @@ func Test_listCerts(t *testing.T) { wantErr bool }{ { - name: "normal response", + name: "normal-response", caller: fakeAPICaller{ Data: json.RawMessage(`{ "certificates" : [ @@ -117,12 +117,12 @@ func Test_listCerts(t *testing.T) { }, }, { - name: "call error", + name: "call-error", caller: fakeAPICaller{nil, fmt.Errorf("caller failed")}, wantErr: true, }, { - name: "payload decode error", + name: "payload-decode-error", caller: fakeAPICaller{json.RawMessage("This isn't JSON!"), nil}, wantErr: true, }, diff --git a/cmd/tailscale/cli/configure_linux.go b/cmd/tailscale/cli/configure_linux.go index 9ba3b8e878d52..da04449087558 100644 --- a/cmd/tailscale/cli/configure_linux.go +++ b/cmd/tailscale/cli/configure_linux.go @@ -18,7 +18,7 @@ func init() { maybeSystrayCmd = systrayConfigCmd } -var systrayArgs struct { +var configSystrayArgs struct { initSystem string installStartup bool } @@ -32,7 +32,7 @@ func systrayConfigCmd() *ffcli.Command { Exec: configureSystray, FlagSet: (func() *flag.FlagSet { fs := newFlagSet("systray") - fs.StringVar(&systrayArgs.initSystem, "enable-startup", "", + fs.StringVar(&configSystrayArgs.initSystem, "enable-startup", "", "Install startup script for init system. Currently supported systems are [systemd, freedesktop].") return fs })(), @@ -40,8 +40,8 @@ func systrayConfigCmd() *ffcli.Command { } func configureSystray(_ context.Context, _ []string) error { - if systrayArgs.initSystem != "" { - if err := systray.InstallStartupScript(systrayArgs.initSystem); err != nil { + if configSystrayArgs.initSystem != "" { + if err := systray.InstallStartupScript(configSystrayArgs.initSystem); err != nil { fmt.Printf("%s\n\n", err.Error()) return flag.ErrHelp } diff --git a/cmd/tailscale/cli/debug-cachenetmap.go b/cmd/tailscale/cli/debug-cachenetmap.go new file mode 100644 index 0000000000000..735469ee42c3d --- /dev/null +++ b/cmd/tailscale/cli/debug-cachenetmap.go @@ -0,0 +1,31 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !ts_omit_cachenetmap + +package cli + +import ( + "context" + "errors" + + "github.com/peterbourgon/ff/v3/ffcli" +) + +func init() { + debugClearNetmapCacheCmd = func() *ffcli.Command { + return &ffcli.Command{ + Name: "clear-netmap-cache", + ShortUsage: "tailscale debug clear-netmap-cache", + ShortHelp: "Remove and discard cached network maps (if any)", + Exec: runDebugClearNetmapCache, + } + } +} + +func runDebugClearNetmapCache(ctx context.Context, args []string) error { + if len(args) != 0 { + return errors.New("unexpected arguments") + } + return localClient.DebugAction(ctx, "clear-netmap-cache") +} diff --git a/cmd/tailscale/cli/debug.go b/cmd/tailscale/cli/debug.go index 629c694c0c6b4..3531172bbf1f6 100644 --- a/cmd/tailscale/cli/debug.go +++ b/cmd/tailscale/cli/debug.go @@ -51,9 +51,10 @@ import ( ) var ( - debugCaptureCmd func() *ffcli.Command // or nil - debugPortmapCmd func() *ffcli.Command // or nil - debugPeerRelayCmd func() *ffcli.Command // or nil + debugCaptureCmd func() *ffcli.Command // or nil + debugPortmapCmd func() *ffcli.Command // or nil + debugPeerRelayCmd func() *ffcli.Command // or nil + debugClearNetmapCacheCmd func() *ffcli.Command // or nil ) func debugCmd() *ffcli.Command { @@ -387,7 +388,14 @@ func debugCmd() *ffcli.Command { return fs })(), }, + { + Name: "statedir", + ShortUsage: "tailscale debug statedir", + ShortHelp: "Print the location of the state directory (if any)", + Exec: runPrintStateDir, + }, ccall(debugPeerRelayCmd), + ccall(debugClearNetmapCacheCmd), }...), } } @@ -662,18 +670,11 @@ func runNetmap(ctx context.Context, args []string) error { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - var mask ipn.NotifyWatchOpt = ipn.NotifyInitialNetMap - watcher, err := localClient.WatchIPNBus(ctx, mask) - if err != nil { - return err - } - defer watcher.Close() - - n, err := watcher.Next() + raw, err := localClient.DebugResultJSON(ctx, "current-netmap") if err != nil { return err } - j, _ := json.MarshalIndent(n.NetMap, "", "\t") + j, _ := json.MarshalIndent(raw, "", "\t") fmt.Printf("%s\n", j) return nil } @@ -1407,3 +1408,22 @@ func runTestRisk(ctx context.Context, args []string) error { fmt.Println("did-test-risky-action") return nil } + +func runPrintStateDir(ctx context.Context, args []string) error { + if len(args) > 0 { + return errors.New("unexpected arguments") + } + v, err := localClient.DebugResultJSON(ctx, "statedir") + if err != nil { + return err + } + statedir, ok := v.(string) + if ok && statedir != "" { + fmt.Println(statedir) + return nil + } else if ok && statedir == "" { + return errors.New("no statedir is set") + } else { + return fmt.Errorf("got unexpected response from debug API: %v", v) + } +} diff --git a/cmd/tailscale/cli/dns-status.go b/cmd/tailscale/cli/dns-status.go index 66a5e21d89700..91a62f996cc54 100644 --- a/cmd/tailscale/cli/dns-status.go +++ b/cmd/tailscale/cli/dns-status.go @@ -14,9 +14,7 @@ import ( "github.com/peterbourgon/ff/v3/ffcli" "tailscale.com/cmd/tailscale/cli/jsonoutput" - "tailscale.com/ipn" "tailscale.com/types/dnstype" - "tailscale.com/types/netmap" ) var dnsStatusCmd = &ffcli.Command{ @@ -120,11 +118,10 @@ func runDNSStatus(ctx context.Context, args []string) error { SelfDNSName: s.Self.DNSName, } - netMap, err := fetchNetMap() + dnsConfig, err := localClient.DNSConfig(ctx) if err != nil { - return fmt.Errorf("failed to fetch network map: %w", err) + return fmt.Errorf("failed to fetch DNS config: %w", err) } - dnsConfig := netMap.DNS for _, r := range dnsConfig.Resolvers { data.Resolvers = append(data.Resolvers, makeDNSResolverInfo(r)) @@ -357,19 +354,3 @@ func formatDNSStatusText(data *jsonoutput.DNSStatusResult, all bool) string { fmt.Fprintf(&sb, "[this is a preliminary version of this command; the output format may change in the future]\n") return sb.String() } - -func fetchNetMap() (netMap *netmap.NetworkMap, err error) { - w, err := localClient.WatchIPNBus(context.Background(), ipn.NotifyInitialNetMap) - if err != nil { - return nil, err - } - defer w.Close() - notify, err := w.Next() - if err != nil { - return nil, err - } - if notify.NetMap == nil { - return nil, fmt.Errorf("no network map yet available, please try again later") - } - return notify.NetMap, nil -} diff --git a/cmd/tailscale/cli/drive_macgui.go b/cmd/tailscale/cli/drive_macgui.go new file mode 100644 index 0000000000000..8a4594f86c4c9 --- /dev/null +++ b/cmd/tailscale/cli/drive_macgui.go @@ -0,0 +1,33 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_drive && ts_mac_gui + +package cli + +import ( + "context" + "errors" + + "github.com/peterbourgon/ff/v3/ffcli" +) + +func init() { + maybeDriveCmd = driveCmdStub +} + +func driveCmdStub() *ffcli.Command { + return &ffcli.Command{ + Name: "drive", + ShortHelp: "Share a directory with your tailnet", + ShortUsage: "tailscale drive [...any]", + LongHelp: hidden + "Taildrive allows you to share directories with other machines on your tailnet.", + Exec: func(_ context.Context, args []string) error { + return errors.New( + "Taildrive CLI commands are not supported when using the macOS GUI app. " + + "Please use the Tailscale menu bar icon to configure Taildrive in Settings.\n\n" + + "See https://tailscale.com/docs/features/taildrive", + ) + }, + } +} diff --git a/cmd/tailscale/cli/drive_macgui_test.go b/cmd/tailscale/cli/drive_macgui_test.go new file mode 100644 index 0000000000000..11f72b13a578a --- /dev/null +++ b/cmd/tailscale/cli/drive_macgui_test.go @@ -0,0 +1,66 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_drive && ts_mac_gui + +package cli + +import ( + "bytes" + "context" + "flag" + "strings" + "testing" + + "github.com/peterbourgon/ff/v3/ffcli" +) + +// In macOS GUI builds, the `drive` command should not appear in +// the help text generated by ffcli. +func TestDriveCommandHiddenInHelpText(t *testing.T) { + root := newRootCmd() + + var buf bytes.Buffer + root.FlagSet = flag.NewFlagSet("tailscale", flag.ContinueOnError) + root.FlagSet.SetOutput(&buf) + + ffcli.DefaultUsageFunc(root) + + output := buf.String() + + if strings.Contains(output, "drive") { + t.Errorf("found hidden command 'drive' in help output:\n%q", output) + } +} + +// Running the drive command always prints an error pointing you to +// the GUI app, regardless of input. +func TestDriveCommandPrintsError(t *testing.T) { + commands := [][]string{ + {"drive"}, + {"drive", "share", "myfile.txt", "/path/to/myfile.txt"}, + {"drive", "rename", "oldname.txt", "newname.txt"}, + {"drive", "unshare", "myfile.txt"}, + {"drive", "list"}, + } + + for _, args := range commands { + root := newRootCmd() + + if err := root.Parse(args); err != nil { + t.Errorf("unable to parse args %q, got err %v", args, err) + continue + } + + t.Logf("running `tailscale drive %q`", strings.Join(args, " ")) + err := root.Run(context.Background()) + if err == nil { + t.Error("expected error, but got nil", args) + } + + expectedText := "Taildrive CLI commands are not supported when using the macOS GUI app." + if !strings.Contains(err.Error(), expectedText) { + t.Errorf("error was not expected: want %q, got %q", expectedText, err.Error()) + } + } +} diff --git a/cmd/tailscale/cli/exitnode.go b/cmd/tailscale/cli/exitnode.go index 0445b66ae14ff..7ba4859d79463 100644 --- a/cmd/tailscale/cli/exitnode.go +++ b/cmd/tailscale/cli/exitnode.go @@ -138,7 +138,7 @@ func runExitNodeList(ctx context.Context, args []string) error { fmt.Fprintln(w) fmt.Fprintln(w) fmt.Fprintln(w, "# To view the complete list of exit nodes for a country, use `tailscale exit-node list --filter=` followed by the country name.") - fmt.Fprintln(w, "# To use an exit node, use `tailscale set --exit-node=` followed by the hostname or IP.") + fmt.Fprintln(w, "# To use an exit node, use `tailscale set --exit-node=` followed by the IP or hostname.") if hasAnyExitNodeSuggestions(peers) { fmt.Fprintln(w, "# To have Tailscale suggest an exit node, use `tailscale exit-node suggest`.") } diff --git a/cmd/tailscale/cli/exitnode_test.go b/cmd/tailscale/cli/exitnode_test.go index 9a77cf5d7d3fd..d7906b929ff57 100644 --- a/cmd/tailscale/cli/exitnode_test.go +++ b/cmd/tailscale/cli/exitnode_test.go @@ -14,7 +14,7 @@ import ( ) func TestFilterFormatAndSortExitNodes(t *testing.T) { - t.Run("without filter", func(t *testing.T) { + t.Run("without-filter", func(t *testing.T) { ps := []*ipnstate.PeerStatus{ { HostName: "everest-1", @@ -139,7 +139,7 @@ func TestFilterFormatAndSortExitNodes(t *testing.T) { } }) - t.Run("with country filter", func(t *testing.T) { + t.Run("with-country-filter", func(t *testing.T) { ps := []*ipnstate.PeerStatus{ { HostName: "baker-1", diff --git a/cmd/tailscale/cli/file.go b/cmd/tailscale/cli/file.go index 94b36f535bcab..489c83deb4fed 100644 --- a/cmd/tailscale/cli/file.go +++ b/cmd/tailscale/cli/file.go @@ -19,6 +19,7 @@ import ( "os" "path" "path/filepath" + "slices" "strings" "sync" "sync/atomic" @@ -31,6 +32,7 @@ import ( "tailscale.com/client/tailscale/apitype" "tailscale.com/cmd/tailscale/cli/ffcomplete" "tailscale.com/envknob" + "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" @@ -77,14 +79,16 @@ var fileCpCmd = &ffcli.Command{ fs.StringVar(&cpArgs.name, "name", "", "alternate filename to use, especially useful when is \"-\" (stdin)") fs.BoolVar(&cpArgs.verbose, "verbose", false, "verbose output") fs.BoolVar(&cpArgs.targets, "targets", false, "list possible file cp targets") + fs.DurationVar(&cpArgs.updateInterval, "update-interval", 250*time.Millisecond, "how often to repaint the progress line; zero or negative disables progress display entirely") return fs })(), } var cpArgs struct { - name string - verbose bool - targets bool + name string + verbose bool + targets bool + updateInterval time.Duration } func runCp(ctx context.Context, args []string) error { @@ -118,22 +122,61 @@ func runCp(ctx context.Context, args []string) error { if err != nil { return fmt.Errorf("can't send to %s: %v", target, err) } - if isOffline { - fmt.Fprintf(Stderr, "# warning: %s is offline\n", target) - } if len(files) > 1 { if cpArgs.name != "" { return errors.New("can't use --name= with multiple files") } - for _, fileArg := range files { - if fileArg == "-" { - return errors.New("can't use '-' as STDIN file when providing filename arguments") + if slices.Contains(files, "-") { + return errors.New("can't use '-' as STDIN file when providing filename arguments") + } + } + + // outFiles tracks per-name push state, populated by a goroutine subscribed + // to the IPN bus. tailscaled's OutgoingFile.Sent is the bytes-pulled-toward- + // peerAPI signal; it stays at 0 until the peerAPI request body is actually + // being read, which is what we want both for the progress display and for + // disarming the offline warning. The CLI's local-side bytes counter would + // say "100% sent" the moment net/http buffers a small body into the local + // unix-socket conn to tailscaled, well before the peer has heard a thing. + type pushState struct { + sent atomic.Int64 + warnTimer *time.Timer // disarmed on first byte sent to peerAPI; nil after + } + var ( + outMu sync.Mutex + outFiles = map[string]*pushState{} // keyed by file name + ) + + busCtx, cancelBus := context.WithCancel(ctx) + defer cancelBus() + go watchOutgoingFiles(busCtx, stableID, func(name string, sent int64) { + outMu.Lock() + ps := outFiles[name] + outMu.Unlock() + if ps == nil { + return + } + // Only ever advance ps.sent forward. Bus updates can arrive late + // (after the success path below has already written contentLength + // to ps.sent for an instant final-100% paint), so we'd otherwise + // regress the count and the progress printer would compute a + // negative delta on its next tick. + for { + old := ps.sent.Load() + if sent <= old { + return + } + if ps.sent.CompareAndSwap(old, sent) { + if old == 0 && ps.warnTimer != nil { + ps.warnTimer.Stop() + } + return } } - } + }) - for _, fileArg := range files { + for i, fileArg := range files { var fileContents *countingReader var name = cpArgs.name var contentLength int64 = -1 @@ -176,16 +219,57 @@ func runCp(ctx context.Context, args []string) error { log.Printf("sending %q to %v/%v/%v ...", name, target, ip, stableID) } + // Register this file with the watcher and, for the first file only, + // arm a timer that warns the user if no bytes have flowed to peerAPI + // after a few seconds. The watcher disarms it on first byte; PushFile + // returning also disarms it (cleanup, below). We don't gate on the + // netmap's Online bit (which can lag reality), but we do use it to + // pick between two warning messages. + ps := &pushState{} + if i == 0 { + ps.warnTimer = time.AfterFunc(3*time.Second, func() { + // vtRestartLine clears whatever (possibly progress) was on + // the current line, then we print the warning + \n so the + // next progress redraw lands on a fresh line below. + const vtRestartLine = "\r\x1b[K" + if isOffline { + fmt.Fprintf(Stderr, "%s# warning: %s is reportedly offline; trying anyway\n", vtRestartLine, target) + } else { + fmt.Fprintf(Stderr, "%s# warning: %s is not replying; trying anyway\n", vtRestartLine, target) + } + }) + } + outMu.Lock() + outFiles[name] = ps + outMu.Unlock() + var group sync.WaitGroup ctxProgress, cancelProgress := context.WithCancel(ctx) defer cancelProgress() - if isatty.IsTerminal(os.Stderr.Fd()) { - group.Go(func() { progressPrinter(ctxProgress, name, fileContents.n.Load, contentLength) }) + if cpArgs.updateInterval > 0 && isatty.IsTerminal(os.Stderr.Fd()) { + group.Go(func() { + progressPrinter(ctxProgress, name, ps.sent.Load, contentLength, cpArgs.updateInterval) + }) } err := localClient.PushFile(ctx, stableID, contentLength, name, fileContents) + if err == nil { + // PushFile can finish faster than the IPN bus delivers a final + // OutgoingFile update, leaving the progress display stuck at 0%. + // Synthesize a "fully done" count before stopping the printer so + // its final paint shows 100%. For stdin (contentLength == -1) we + // don't know the size, so fall back to the local read count. + if contentLength >= 0 { + ps.sent.Store(contentLength) + } else { + ps.sent.Store(fileContents.n.Load()) + } + } cancelProgress() group.Wait() // wait for progress printer to stop before reporting the error + if ps.warnTimer != nil { + ps.warnTimer.Stop() + } if err != nil { return err } @@ -196,15 +280,71 @@ func runCp(ctx context.Context, args []string) error { return nil } -func progressPrinter(ctx context.Context, name string, contentCount func() int64, contentLength int64) { +// watchOutgoingFiles subscribes to the IPN bus and invokes onUpdate once +// per OutgoingFile event for files going to peer. It runs until ctx is +// done (which runCp does on return) and is best-effort: if the bus +// subscription fails for any reason, onUpdate simply isn't called and the +// caller's progress display stays at 0 — exactly the right degradation, +// since the warning timer will then fire on its normal 3-second deadline. +func watchOutgoingFiles(ctx context.Context, peer tailcfg.StableNodeID, onUpdate func(name string, sent int64)) { + // NotifyPeerChanges opts in to per-peer add/remove notifications so the + // bus stays responsive without us also subscribing to the full NetMap, + // which we don't read here. + w, err := localClient.WatchIPNBus(ctx, ipn.NotifyInitialOutgoingFiles|ipn.NotifyPeerChanges) + if err != nil { + return + } + defer w.Close() + for { + n, err := w.Next() + if err != nil { + return + } + for _, of := range n.OutgoingFiles { + if of.PeerID != peer { + continue + } + // tailscaled keeps Finished entries in its OutgoingFiles map + // across PushFile calls (see feature/taildrop/ext.go), so a + // re-send of the same filename will see both the old completed + // (Sent == DeclaredSize) entry and the new in-progress one. + // Without this filter the watcher's monotonic CAS would latch + // onto the old entry's max value and the new transfer would + // appear stuck at 100% from the first bus tick. + if of.Finished { + continue + } + onUpdate(of.Name, of.Sent) + } + } +} + +// progressPrinter repaints a single-line transfer progress display every +// interval. interval must be > 0; runCp's caller gates on the +// --update-interval flag and skips invoking us when it's <= 0. +// +// It returns when ctx is done OR when it detects the transfer is stuck — +// "stuck" being: contentCount has equalled contentLength with a near-zero +// rate for >2 seconds. The stuck case prints a final newline so subsequent +// output (e.g. an error from PushFile) lands on a fresh line below the +// frozen progress line, instead of being painted over by it. +func progressPrinter(ctx context.Context, name string, contentCount func() int64, contentLength int64, interval time.Duration) { var rateValueFast, rateValueSlow tsrate.Value - rateValueFast.HalfLife = 1 * time.Second // fast response for rate measurement - rateValueSlow.HalfLife = 10 * time.Second // slow response for ETA measurement + // tailscaled emits OutgoingFile.Sent updates at ~1 Hz, so most printer + // ticks see no delta. With too short a half-life the displayed rate + // roughly halves between updates and doubles back when one arrives, + // looking jumpy. 5s keeps the swing under ~15% while still settling + // within a few seconds of a real change. + rateValueFast.HalfLife = 5 * time.Second // smoothed rate for display + rateValueSlow.HalfLife = 10 * time.Second // even slower, for ETA measurement var prevContentCount int64 print := func() { currContentCount := contentCount() - rateValueFast.Add(float64(currContentCount - prevContentCount)) - rateValueSlow.Add(float64(currContentCount - prevContentCount)) + // Clamp so a regression (which shouldn't happen, but tsrate.Value.Add + // panics on a negative count) can't take down the CLI. + delta := max(currContentCount-prevContentCount, 0) + rateValueFast.Add(float64(delta)) + rateValueSlow.Add(float64(delta)) prevContentCount = currContentCount const vtRestartLine = "\r\x1b[K" @@ -216,16 +356,23 @@ func progressPrinter(ctx context.Context, name string, contentCount func() int64 if contentLength >= 0 { currContentCount = min(currContentCount, contentLength) // cap at 100% ratioRemain := float64(currContentCount) / float64(contentLength) - bytesRemain := float64(contentLength - currContentCount) - secsRemain := bytesRemain / rateValueSlow.Rate() - secs := int(min(max(0, secsRemain), 99*60*60+59+60+59)) + etaStr := "ETA -" + if rate := rateValueSlow.Rate(); rate > 0 { + bytesRemain := float64(contentLength - currContentCount) + secsRemain := bytesRemain / rate + secs := int(min(max(0, secsRemain), 99*60*60+59+60+59)) + etaStr = fmt.Sprintf("ETA %02d:%02d:%02d", secs/60/60, (secs/60)%60, secs%60) + } fmt.Fprintf(os.Stderr, " %s %s", leftPad(fmt.Sprintf("%0.2f%%", 100.0*ratioRemain), len("100.00%")), - fmt.Sprintf("ETA %02d:%02d:%02d", secs/60/60, (secs/60)%60, secs%60)) + etaStr) } } - tc := time.NewTicker(250 * time.Millisecond) + const stuckAfter = 2 * time.Second + var fullStartedAt time.Time // when we first observed currCount==contentLength with ~zero rate + + tc := time.NewTicker(interval) defer tc.Stop() print() for { @@ -236,6 +383,24 @@ func progressPrinter(ctx context.Context, name string, contentCount func() int64 return case <-tc.C: print() + if contentLength < 0 { + continue + } + currCount := contentCount() + rate := rateValueFast.Rate() + if currCount >= contentLength && rate < 1 { + if fullStartedAt.IsZero() { + fullStartedAt = time.Now() + } else if time.Since(fullStartedAt) >= stuckAfter { + // Transfer is stuck at 100% with no movement. Stop + // repainting so we don't keep clobbering anything the + // rest of runCp prints (warnings, errors). + fmt.Fprintln(os.Stderr) + return + } + } else { + fullStartedAt = time.Time{} + } } } } @@ -329,7 +494,10 @@ peerLoop: return "", isOffline, errors.New("cannot send files: missing required Taildrop capability") case ipnstate.TaildropTargetOffline: - return "", isOffline, errors.New("cannot send files: peer is offline") + // Don't gate on the server-reported Online bit (which lags reality + // and isn't always accurate). runCp probes reachability itself with + // TSMP pings. + return foundPeer.ID, isOffline, nil case ipnstate.TaildropTargetNoPeerInfo: return "", isOffline, errors.New("cannot send files: invalid or unrecognized peer") diff --git a/cmd/tailscale/cli/ip.go b/cmd/tailscale/cli/ip.go index 7159904c732d6..b76ef0a708b3a 100644 --- a/cmd/tailscale/cli/ip.go +++ b/cmd/tailscale/cli/ip.go @@ -9,6 +9,7 @@ import ( "flag" "fmt" "net/netip" + "slices" "github.com/peterbourgon/ff/v3/ffcli" "tailscale.com/ipn/ipnstate" @@ -114,17 +115,13 @@ func peerMatchingIP(st *ipnstate.Status, ipStr string) (ps *ipnstate.PeerStatus, return } for _, ps = range st.Peer { - for _, pip := range ps.TailscaleIPs { - if ip == pip { - return ps, true - } + if slices.Contains(ps.TailscaleIPs, ip) { + return ps, true } } if ps := st.Self; ps != nil { - for _, pip := range ps.TailscaleIPs { - if ip == pip { - return ps, true - } + if slices.Contains(ps.TailscaleIPs, ip) { + return ps, true } } return nil, false diff --git a/cmd/tailscale/cli/jsonoutput/network-lock-log.go b/cmd/tailscale/cli/jsonoutput/network-lock-log.go index c3190e6bac9c7..c7c16e223511d 100644 --- a/cmd/tailscale/cli/jsonoutput/network-lock-log.go +++ b/cmd/tailscale/cli/jsonoutput/network-lock-log.go @@ -76,8 +76,8 @@ func toLogMessageV1(aum tka.AUM, update ipnstate.NetworkLockUpdate) logMessageV1 if h := state.LastAUMHash; h != nil { expandedState.LastAUMHash = h.String() } - for _, secret := range state.DisablementSecrets { - expandedState.DisablementSecrets = append(expandedState.DisablementSecrets, fmt.Sprintf("%x", secret)) + for _, secret := range state.DisablementValues { + expandedState.DisablementValues = append(expandedState.DisablementValues, fmt.Sprintf("%x", secret)) } for _, key := range state.Keys { expandedState.Keys = append(expandedState.Keys, toTKAKeyV1(&key)) @@ -180,9 +180,13 @@ type expandedStateV1 struct { // LastAUMHash is the blake2s digest of the last-applied AUM. LastAUMHash string `json:"LastAUMHash,omitzero"` - // DisablementSecrets are KDF-derived values which can be used - // to turn off the TKA in the event of a consensus-breaking bug. - DisablementSecrets []string + // DisablementValues are KDF-derived values used to verify that a caller + // possesses a valid DisablementSecret. These values are used during the + // Tailnet Lock deactivation process. + // + // These are safe to share publicly or store in the clear. They cannot be + // used to derive the original DisablementSecret. + DisablementValues []string // Keys are the public keys of either: // diff --git a/cmd/tailscale/cli/network-lock.go b/cmd/tailscale/cli/network-lock.go index 9ec0e1d7fe819..4febd56a97365 100644 --- a/cmd/tailscale/cli/network-lock.go +++ b/cmd/tailscale/cli/network-lock.go @@ -305,9 +305,7 @@ var nlAddCmd = &ffcli.Command{ Name: "add", ShortUsage: "tailscale lock add ...", ShortHelp: "Add one or more trusted signing keys to tailnet lock", - Exec: func(ctx context.Context, args []string) error { - return runNetworkLockModify(ctx, args, nil) - }, + Exec: runNetworkLockAdd, } var nlRemoveArgs struct { @@ -331,6 +329,9 @@ func runNetworkLockRemove(ctx context.Context, args []string) error { if err != nil { return err } + if len(removeKeys) == 0 { + return fmt.Errorf("missing argument, expected one or more tailnet lock keys") + } st, err := localClient.NetworkLockStatus(ctx) if err != nil { return fixTailscaledConnectError(err) @@ -445,25 +446,24 @@ func parseNLArgs(args []string, parseKeys, parseDisablements bool) (keys []tka.K return keys, disablements, nil } -func runNetworkLockModify(ctx context.Context, addArgs, removeArgs []string) error { - st, err := localClient.NetworkLockStatus(ctx) +func runNetworkLockAdd(ctx context.Context, addArgs []string) error { + addKeys, _, err := parseNLArgs(addArgs, true, false) if err != nil { - return fixTailscaledConnectError(err) + return err } - if !st.Enabled { - return errors.New("tailnet lock is not enabled") + if len(addKeys) == 0 { + return fmt.Errorf("missing argument, expected one or more tailnet lock keys") } - addKeys, _, err := parseNLArgs(addArgs, true, false) + st, err := localClient.NetworkLockStatus(ctx) if err != nil { - return err + return fixTailscaledConnectError(err) } - removeKeys, _, err := parseNLArgs(removeArgs, true, false) - if err != nil { - return err + if !st.Enabled { + return errors.New("tailnet lock is not enabled") } - if err := localClient.NetworkLockModify(ctx, addKeys, removeKeys); err != nil { + if err := localClient.NetworkLockModify(ctx, addKeys, nil); err != nil { return err } return nil @@ -672,7 +672,7 @@ func nlDescribeUpdate(update ipnstate.NetworkLockUpdate, color bool) (string, er case tka.AUMCheckpoint.String(): fmt.Fprintln(&stanza, "Disablement values:") - for _, v := range aum.State.DisablementSecrets { + for _, v := range aum.State.DisablementValues { fmt.Fprintf(&stanza, " - %x\n", v) } fmt.Fprintln(&stanza, "Keys:") @@ -819,13 +819,17 @@ Revocation is a multi-step process that requires several signing nodes to ` + "` func runNetworkLockRevokeKeys(ctx context.Context, args []string) error { // First step in the process if !nlRevokeKeysArgs.cosign && !nlRevokeKeysArgs.finish { - removeKeys, _, err := parseNLArgs(args, true, false) + revokeKeys, _, err := parseNLArgs(args, true, false) if err != nil { return err } - keyIDs := make([]tkatype.KeyID, len(removeKeys)) - for i, k := range removeKeys { + if len(revokeKeys) == 0 { + return fmt.Errorf("missing argument, expected one or more tailnet lock keys") + } + + keyIDs := make([]tkatype.KeyID, len(revokeKeys)) + for i, k := range revokeKeys { keyIDs[i], err = k.ID() if err != nil { return fmt.Errorf("generating keyID: %v", err) diff --git a/cmd/tailscale/cli/network-lock_test.go b/cmd/tailscale/cli/network-lock_test.go index 596a51b8a2deb..8e49265bff181 100644 --- a/cmd/tailscale/cli/network-lock_test.go +++ b/cmd/tailscale/cli/network-lock_test.go @@ -54,7 +54,7 @@ func TestNetworkLockLogOutput(t *testing.T) { Meta: map[string]string{"en": "one", "de": "eins", "es": "uno"}, }, }, - DisablementSecrets: [][]byte{ + DisablementValues: [][]byte{ {1, 2, 3}, {4, 5, 6}, {7, 8, 9}, @@ -125,7 +125,7 @@ KeyID: tlpub:0202 "MessageKind": "checkpoint", "PrevAUMHash": "BKVVXHOVBW7Y7YXYTLVVLMNSYG6DS5GVRVSYZLASNU3AQKA732XQ", "State": { - "DisablementSecrets": [ + "DisablementValues": [ "010203", "040506", "070809" diff --git a/cmd/tailscale/cli/risks.go b/cmd/tailscale/cli/risks.go index 1bd128d566125..6f3ebf37bbebe 100644 --- a/cmd/tailscale/cli/risks.go +++ b/cmd/tailscale/cli/risks.go @@ -39,7 +39,7 @@ func registerAcceptRiskFlag(f *flag.FlagSet, acceptedRisks *string) { // isRiskAccepted reports whether riskType is in the comma-separated list of // risks in acceptedRisks. func isRiskAccepted(riskType, acceptedRisks string) bool { - for _, r := range strings.Split(acceptedRisks, ",") { + for r := range strings.SplitSeq(acceptedRisks, ",") { if r == riskType || r == riskAll { return true } diff --git a/cmd/tailscale/cli/serve_legacy.go b/cmd/tailscale/cli/serve_legacy.go index 837d8851368e4..635bcfa3d6fe2 100644 --- a/cmd/tailscale/cli/serve_legacy.go +++ b/cmd/tailscale/cli/serve_legacy.go @@ -848,10 +848,10 @@ func (e *serveEnv) enableFeatureInteractive(ctx context.Context, feature string, e.lc.IncrementCounter(ctx, fmt.Sprintf("%s_enablement_lost_connection", feature), 1) return err } - if nm := n.NetMap; nm != nil && nm.SelfNode.Valid() { + if self := n.SelfChange; self != nil { gotAll := true for _, c := range caps { - if !nm.SelfNode.HasCap(c) { + if _, has := self.CapMap[c]; !has { // The feature is not yet enabled. // Continue blocking until it is. gotAll = false diff --git a/cmd/tailscale/cli/serve_v2.go b/cmd/tailscale/cli/serve_v2.go index 840c47ac66dd1..13f5c09b8eac5 100644 --- a/cmd/tailscale/cli/serve_v2.go +++ b/cmd/tailscale/cli/serve_v2.go @@ -114,8 +114,8 @@ func (u *acceptAppCapsFlag) Set(s string) error { if s == "" { return nil } - appCaps := strings.Split(s, ",") - for _, appCap := range appCaps { + appCaps := strings.SplitSeq(s, ",") + for appCap := range appCaps { appCap = strings.TrimSpace(appCap) if !validAppCap.MatchString(appCap) { return fmt.Errorf("%q does not match the form {domain}/{name}, where domain must be a fully qualified domain name", appCap) @@ -1096,7 +1096,7 @@ func isRemote(target string) bool { target = "tmp://" + target } - // make sure we can parse the target, wether it's a full URL or just a host:port + // make sure we can parse the target, whether it's a full URL or just a host:port u, err := url.ParseRequestURI(target) if err != nil { // If we can't parse the target, it doesn't matter if it's remote or not diff --git a/cmd/tailscale/cli/serve_v2_test.go b/cmd/tailscale/cli/serve_v2_test.go index 7b27de6f2eb26..1d2a8ef86a2e4 100644 --- a/cmd/tailscale/cli/serve_v2_test.go +++ b/cmd/tailscale/cli/serve_v2_test.go @@ -1056,49 +1056,49 @@ func TestSrcTypeFromFlags(t *testing.T) { expectedErr bool }{ { - name: "only http set", + name: "only-http-set", env: &serveEnv{http: 80}, expectedType: serveTypeHTTP, expectedPort: 80, expectedErr: false, }, { - name: "only https set", + name: "only-https-set", env: &serveEnv{https: 10000}, expectedType: serveTypeHTTPS, expectedPort: 10000, expectedErr: false, }, { - name: "only tcp set", + name: "only-tcp-set", env: &serveEnv{tcp: 8000}, expectedType: serveTypeTCP, expectedPort: 8000, expectedErr: false, }, { - name: "only tls-terminated-tcp set", + name: "only-tls-terminated-tcp-set", env: &serveEnv{tlsTerminatedTCP: 8080}, expectedType: serveTypeTLSTerminatedTCP, expectedPort: 8080, expectedErr: false, }, { - name: "defaults to https, port 443", + name: "defaults-to-https-443", env: &serveEnv{}, expectedType: serveTypeHTTPS, expectedPort: 443, expectedErr: false, }, { - name: "defaults to https, port 443 for service", + name: "defaults-to-https-443-for-service", env: &serveEnv{service: "svc:foo"}, expectedType: serveTypeHTTPS, expectedPort: 443, expectedErr: false, }, { - name: "multiple types set", + name: "multiple-types-set", env: &serveEnv{http: 80, https: 443}, expectedPort: 0, expectedErr: true, @@ -1235,19 +1235,20 @@ func TestAcceptSetAppCapsFlag(t *testing.T) { func TestCleanURLPath(t *testing.T) { tests := []struct { + name string input string expected string wantErr bool }{ - {input: "", expected: "/"}, - {input: "/", expected: "/"}, - {input: "/foo", expected: "/foo"}, - {input: "/foo/", expected: "/foo/"}, - {input: "/../bar", wantErr: true}, + {name: "empty", input: "", expected: "/"}, + {name: "slash", input: "/", expected: "/"}, + {name: "foo", input: "/foo", expected: "/foo"}, + {name: "foo-trailing-slash", input: "/foo/", expected: "/foo/"}, + {name: "dotdot-bar", input: "/../bar", wantErr: true}, } for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { actual, err := cleanURLPath(tt.input) if tt.wantErr == true && err == nil { @@ -1275,18 +1276,18 @@ func TestAddServiceToPrefs(t *testing.T) { expected []string }{ { - name: "add service to empty prefs", + name: "add-service-to-empty-prefs", svcName: "svc:foo", expected: []string{"svc:foo"}, }, { - name: "add service to existing prefs", + name: "add-service-to-existing-prefs", svcName: "svc:bar", startServices: []string{"svc:foo"}, expected: []string{"svc:foo", "svc:bar"}, }, { - name: "add existing service to prefs", + name: "add-existing-service-to-prefs", svcName: "svc:foo", startServices: []string{"svc:foo"}, expected: []string{"svc:foo"}, @@ -1323,18 +1324,18 @@ func TestRemoveServiceFromPrefs(t *testing.T) { expected []string }{ { - name: "remove service from empty prefs", + name: "remove-service-from-empty-prefs", svcName: "svc:foo", expected: []string{}, }, { - name: "remove existing service from prefs", + name: "remove-existing-service-from-prefs", svcName: "svc:foo", startServices: []string{"svc:foo"}, expected: []string{}, }, { - name: "remove service not in prefs", + name: "remove-service-not-in-prefs", svcName: "svc:bar", startServices: []string{"svc:foo"}, expected: []string{"svc:foo"}, @@ -1446,7 +1447,7 @@ func TestMessageForPort(t *testing.T) { }, "\n"), }, { - name: "serve service http", + name: "serve-service-http", subcmd: serve, serveConfig: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -1490,7 +1491,7 @@ func TestMessageForPort(t *testing.T) { }, "\n"), }, { - name: "serve service no capmap", + name: "serve-service-no-capmap", subcmd: serve, serveConfig: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -1534,7 +1535,7 @@ func TestMessageForPort(t *testing.T) { }, "\n"), }, { - name: "serve service https non-default port", + name: "serve-service-https-non-default-port", subcmd: serve, serveConfig: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -1576,7 +1577,7 @@ func TestMessageForPort(t *testing.T) { }, "\n"), }, { - name: "serve service TCPForward", + name: "serve-service-TCPForward", subcmd: serve, serveConfig: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -1613,7 +1614,7 @@ func TestMessageForPort(t *testing.T) { }, "\n"), }, { - name: "serve service Tun", + name: "serve-service-Tun", subcmd: serve, serveConfig: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -1790,7 +1791,7 @@ func TestSetServe(t *testing.T) { expectErr bool }{ { - name: "add new handler", + name: "add-new-handler", desc: "add a new http handler to empty config", cfg: &ipn.ServeConfig{}, dnsName: "foo.test.ts.net", @@ -1810,7 +1811,7 @@ func TestSetServe(t *testing.T) { }, }, { - name: "update http handler", + name: "update-http-handler", desc: "update an existing http handler on the same port to same type", cfg: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}, @@ -1839,7 +1840,7 @@ func TestSetServe(t *testing.T) { }, }, { - name: "update TCP handler", + name: "update-TCP-handler", desc: "update an existing TCP handler on the same port to a http handler", cfg: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{80: {TCPForward: "http://localhost:3000"}}, @@ -1852,7 +1853,7 @@ func TestSetServe(t *testing.T) { expectErr: true, }, { - name: "add new service handler", + name: "add-new-service-handler", desc: "add a new service TCP handler to empty config", cfg: &ipn.ServeConfig{}, @@ -1869,7 +1870,7 @@ func TestSetServe(t *testing.T) { }, }, { - name: "update service handler", + name: "update-service-handler", desc: "update an existing service TCP handler on the same port to same type", cfg: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -1891,7 +1892,7 @@ func TestSetServe(t *testing.T) { }, }, { - name: "update service handler", + name: "update-service-handler", desc: "update an existing service TCP handler on the same port to a http handler", cfg: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -1908,7 +1909,7 @@ func TestSetServe(t *testing.T) { expectErr: true, }, { - name: "add new service handler", + name: "add-new-service-handler", desc: "add a new service HTTP handler to empty config", cfg: &ipn.ServeConfig{}, dnsName: "svc:bar", @@ -1932,7 +1933,7 @@ func TestSetServe(t *testing.T) { }, }, { - name: "update existing service handler", + name: "update-existing-service-handler", desc: "update an existing service HTTP handler", cfg: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -1969,7 +1970,7 @@ func TestSetServe(t *testing.T) { }, }, { - name: "add new service handler", + name: "add-new-service-handler", desc: "add a new service HTTP handler to existing service config", cfg: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -2014,7 +2015,7 @@ func TestSetServe(t *testing.T) { }, }, { - name: "add new service mount", + name: "add-new-service-mount", desc: "add a new service mount to existing service config", cfg: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -2054,7 +2055,7 @@ func TestSetServe(t *testing.T) { }, }, { - name: "add new service handler", + name: "add-new-service-handler", desc: "add a new service handler in tun mode to empty config", cfg: &ipn.ServeConfig{}, dnsName: "svc:bar", @@ -2103,7 +2104,7 @@ func TestUnsetServe(t *testing.T) { expectErr bool }{ { - name: "unset http handler", + name: "unset-http-handler", desc: "remove an existing http handler", cfg: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{ @@ -2128,7 +2129,7 @@ func TestUnsetServe(t *testing.T) { expectErr: false, }, { - name: "unset service handler", + name: "unset-service-handler", desc: "remove an existing service TCP handler", cfg: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -2157,7 +2158,7 @@ func TestUnsetServe(t *testing.T) { expectErr: false, }, { - name: "unset service handler tun", + name: "unset-service-handler-tun", desc: "remove an existing service handler in tun mode", cfg: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -2175,7 +2176,7 @@ func TestUnsetServe(t *testing.T) { expectErr: false, }, { - name: "unset service handler tcp", + name: "unset-service-handler-tcp", desc: "remove an existing service TCP handler", cfg: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -2196,7 +2197,7 @@ func TestUnsetServe(t *testing.T) { expectErr: false, }, { - name: "unset http handler not found", + name: "unset-http-handler-not-found", desc: "try to remove a non-existing http handler", cfg: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{ @@ -2221,7 +2222,7 @@ func TestUnsetServe(t *testing.T) { expectErr: true, }, { - name: "unset service handler not found", + name: "unset-service-handler-not-found", desc: "try to remove a non-existing service TCP handler", cfg: &ipn.ServeConfig{ @@ -2253,7 +2254,7 @@ func TestUnsetServe(t *testing.T) { expectErr: true, }, { - name: "unset service doesn't exist", + name: "unset-service-doesnt-exist", desc: "try to remove a non-existing service's handler", cfg: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -2273,7 +2274,7 @@ func TestUnsetServe(t *testing.T) { expectErr: true, }, { - name: "unset tcp while port is in use", + name: "unset-tcp-while-port-in-use", desc: "try to remove a TCP handler while the port is used for web", cfg: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{ @@ -2297,7 +2298,7 @@ func TestUnsetServe(t *testing.T) { expectErr: true, }, { - name: "unset service tcp while port is in use", + name: "unset-service-tcp-while-port-in-use", desc: "try to remove a service TCP handler while the port is used for web", cfg: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ diff --git a/cmd/tailscale/cli/set.go b/cmd/tailscale/cli/set.go index 22d78641f38a9..6fd4b09ad6790 100644 --- a/cmd/tailscale/cli/set.go +++ b/cmd/tailscale/cli/set.go @@ -24,7 +24,6 @@ import ( "tailscale.com/safesocket" "tailscale.com/tsconst" "tailscale.com/types/opt" - "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/set" "tailscale.com/version" @@ -184,8 +183,7 @@ func runSet(ctx context.Context, args []string) (retErr error) { maskedPrefs.AutoExitNode = expr maskedPrefs.AutoExitNodeSet = true } else if err := maskedPrefs.Prefs.SetExitNodeIP(setArgs.exitNodeIP, st); err != nil { - var e ipn.ExitNodeLocalIPError - if errors.As(err, &e) { + if _, ok := errors.AsType[ipn.ExitNodeLocalIPError](err); ok { return fmt.Errorf("%w; did you mean --advertise-exit-node?", err) } return err @@ -247,13 +245,13 @@ func runSet(ctx context.Context, args []string) (retErr error) { if err != nil { return fmt.Errorf("failed to set relay server port: %v", err) } - maskedPrefs.Prefs.RelayServerPort = ptr.To(uint16(uport)) + maskedPrefs.Prefs.RelayServerPort = new(uint16(uport)) } if setArgs.relayServerStaticEndpoints != "" { endpointsSet := make(set.Set[netip.AddrPort]) - endpointsSplit := strings.Split(setArgs.relayServerStaticEndpoints, ",") - for _, s := range endpointsSplit { + endpointsSplit := strings.SplitSeq(setArgs.relayServerStaticEndpoints, ",") + for s := range endpointsSplit { ap, err := netip.ParseAddrPort(s) if err != nil { return fmt.Errorf("failed to set relay server static endpoints: %q is not a valid IP:port", s) @@ -267,6 +265,11 @@ func runSet(ctx context.Context, args []string) (retErr error) { checkPrefs := curPrefs.Clone() checkPrefs.ApplyEdits(maskedPrefs) + // We want to make sure user is aware setting --snat-subnet-routes=false with --advertise-exit-node would break exitnode, + // but we won't prevent them from doing it since there are current dependencies on that combination. (as of 2026-03-25) + if checkPrefs.NoSNAT && checkPrefs.AdvertisesExitNode() { + warnf("--snat-subnet-routes=false is set with --advertise-exit-node; internet traffic through this exit node may not work as expected") + } if err := localClient.CheckPrefs(ctx, checkPrefs); err != nil { return err } diff --git a/cmd/tailscale/cli/set_test.go b/cmd/tailscale/cli/set_test.go index 63fa3c05c48b3..e2c3ae5f64116 100644 --- a/cmd/tailscale/cli/set_test.go +++ b/cmd/tailscale/cli/set_test.go @@ -11,7 +11,6 @@ import ( "tailscale.com/ipn" "tailscale.com/net/tsaddr" - "tailscale.com/types/ptr" ) func TestCalcAdvertiseRoutesForSet(t *testing.T) { @@ -28,80 +27,80 @@ func TestCalcAdvertiseRoutesForSet(t *testing.T) { }, { name: "advertise-exit", - setExit: ptr.To(true), + setExit: new(true), want: tsaddr.ExitRoutes(), }, { name: "advertise-exit/already-routes", was: []netip.Prefix{pfx("34.0.0.0/16")}, - setExit: ptr.To(true), + setExit: new(true), want: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, }, { name: "advertise-exit/already-exit", was: tsaddr.ExitRoutes(), - setExit: ptr.To(true), + setExit: new(true), want: tsaddr.ExitRoutes(), }, { name: "stop-advertise-exit", was: tsaddr.ExitRoutes(), - setExit: ptr.To(false), + setExit: new(false), want: nil, }, { name: "stop-advertise-exit/with-routes", was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - setExit: ptr.To(false), + setExit: new(false), want: []netip.Prefix{pfx("34.0.0.0/16")}, }, { name: "advertise-routes", - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + setRoutes: new("10.0.0.0/24,192.168.0.0/16"), want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, }, { name: "advertise-routes/already-exit", was: tsaddr.ExitRoutes(), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + setRoutes: new("10.0.0.0/24,192.168.0.0/16"), want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, }, { name: "advertise-routes/already-diff-routes", was: []netip.Prefix{pfx("34.0.0.0/16")}, - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + setRoutes: new("10.0.0.0/24,192.168.0.0/16"), want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, }, { name: "stop-advertise-routes", was: []netip.Prefix{pfx("34.0.0.0/16")}, - setRoutes: ptr.To(""), + setRoutes: new(""), want: nil, }, { name: "stop-advertise-routes/already-exit", was: []netip.Prefix{pfx("34.0.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - setRoutes: ptr.To(""), + setRoutes: new(""), want: tsaddr.ExitRoutes(), }, { name: "advertise-routes-and-exit", - setExit: ptr.To(true), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + setExit: new(true), + setRoutes: new("10.0.0.0/24,192.168.0.0/16"), want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, }, { name: "advertise-routes-and-exit/already-exit", was: tsaddr.ExitRoutes(), - setExit: ptr.To(true), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + setExit: new(true), + setRoutes: new("10.0.0.0/24,192.168.0.0/16"), want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, }, { name: "advertise-routes-and-exit/already-routes", was: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16")}, - setExit: ptr.To(true), - setRoutes: ptr.To("10.0.0.0/24,192.168.0.0/16"), + setExit: new(true), + setRoutes: new("10.0.0.0/24,192.168.0.0/16"), want: []netip.Prefix{pfx("10.0.0.0/24"), pfx("192.168.0.0/16"), tsaddr.AllIPv4(), tsaddr.AllIPv6()}, }, } diff --git a/cmd/tailscale/cli/ssh.go b/cmd/tailscale/cli/ssh.go index bea18f7abf6ac..9efab8cf7e47e 100644 --- a/cmd/tailscale/cli/ssh.go +++ b/cmd/tailscale/cli/ssh.go @@ -14,6 +14,7 @@ import ( "os/user" "path/filepath" "runtime" + "slices" "strings" "github.com/peterbourgon/ff/v3/ffcli" @@ -202,10 +203,8 @@ func peerStatusFromArg(st *ipnstate.Status, arg string) (*ipnstate.PeerStatus, b argIP, _ := netip.ParseAddr(arg) for _, ps := range st.Peer { if argIP.IsValid() { - for _, ip := range ps.TailscaleIPs { - if ip == argIP { - return ps, true - } + if slices.Contains(ps.TailscaleIPs, argIP) { + return ps, true } continue } @@ -230,10 +229,8 @@ func nodeDNSNameFromArg(st *ipnstate.Status, arg string) (dnsName string, ok boo for _, ps := range st.Peer { dnsName = ps.DNSName if argIP.IsValid() { - for _, ip := range ps.TailscaleIPs { - if ip == argIP { - return dnsName, true - } + if slices.Contains(ps.TailscaleIPs, argIP) { + return dnsName, true } continue } diff --git a/cmd/tailscale/cli/ssh_exec_windows.go b/cmd/tailscale/cli/ssh_exec_windows.go index 85e1518175609..f9d306463c635 100644 --- a/cmd/tailscale/cli/ssh_exec_windows.go +++ b/cmd/tailscale/cli/ssh_exec_windows.go @@ -28,9 +28,8 @@ func execSSH(ssh string, argv []string) error { cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - var ee *exec.ExitError err := cmd.Run() - if errors.As(err, &ee) { + if ee, ok := errors.AsType[*exec.ExitError](err); ok { os.Exit(ee.ExitCode()) } return err diff --git a/cmd/tailscale/cli/ssh_unix.go b/cmd/tailscale/cli/ssh_unix.go index 768d71116cf2c..1cc3ccbe8c66f 100644 --- a/cmd/tailscale/cli/ssh_unix.go +++ b/cmd/tailscale/cli/ssh_unix.go @@ -39,7 +39,7 @@ func init() { return "" } prefix := []byte("SSH_CLIENT=") - for _, env := range bytes.Split(b, []byte{0}) { + for env := range bytes.SplitSeq(b, []byte{0}) { if bytes.HasPrefix(env, prefix) { return string(env[len(prefix):]) } diff --git a/cmd/tailscale/cli/systray.go b/cmd/tailscale/cli/systray.go index ca0840fe9271e..07de5c7868fcf 100644 --- a/cmd/tailscale/cli/systray.go +++ b/cmd/tailscale/cli/systray.go @@ -7,6 +7,7 @@ package cli import ( "context" + "flag" "github.com/peterbourgon/ff/v3/ffcli" "tailscale.com/client/systray" @@ -17,10 +18,20 @@ var systrayCmd = &ffcli.Command{ ShortUsage: "tailscale systray", ShortHelp: "Run a systray application to manage Tailscale", LongHelp: "Run a systray application to manage Tailscale.", - Exec: runSystray, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("systray") + fs.StringVar(&systrayArgs.theme, "theme", "dark", "color theme for Tailscale icon: dark, dark:nobg, light, light:nobg") + return fs + })(), + Exec: runSystray, +} + +var systrayArgs struct { + theme string } func runSystray(ctx context.Context, _ []string) error { + systray.SetTheme(systrayArgs.theme) new(systray.Menu).Run(&localClient) return nil } diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index 79cc60ca2347f..fed7de9ae7e01 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -113,12 +113,12 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { upf.BoolVar(&upArgs.exitNodeAllowLANAccess, "exit-node-allow-lan-access", false, "Allow direct access to the local network when routing traffic via an exit node") upf.BoolVar(&upArgs.shieldsUp, "shields-up", false, "don't allow incoming connections") upf.BoolVar(&upArgs.runSSH, "ssh", false, "run an SSH server, permitting access per tailnet admin's declared policy") - upf.StringVar(&upArgs.advertiseTags, "advertise-tags", "", "comma-separated ACL tags to request; each must start with \"tag:\" (e.g. \"tag:eng,tag:montreal,tag:ssh\")") + upf.StringVar(&upArgs.advertiseTags, "advertise-tags", "", "comma-separated ACL tags to request (e.g. \"tag:eng,tag:montreal,tag:ssh\"); the \"tag:\" prefix is optional and added automatically when omitted (e.g. \"eng,montreal,ssh\")") upf.StringVar(&upArgs.hostname, "hostname", "", "hostname to use instead of the one provided by the OS") upf.StringVar(&upArgs.advertiseRoutes, "advertise-routes", "", "routes to advertise to other nodes (comma-separated, e.g. \"10.0.0.0/8,192.168.0.0/24\") or empty string to not advertise routes") upf.BoolVar(&upArgs.advertiseConnector, "advertise-connector", false, "advertise this node as an app connector") upf.BoolVar(&upArgs.advertiseDefaultRoute, "advertise-exit-node", false, "offer to be an exit node for internet traffic for the tailnet") - upf.BoolVar(&upArgs.postureChecking, "report-posture", false, hidden+"allow management plane to gather device posture information") + upf.BoolVar(&upArgs.postureChecking, "report-posture", false, "allow management plane to gather device posture information") if safesocket.GOOSUsesPeerCreds(goos) { upf.StringVar(&upArgs.opUser, "operator", "", "Unix username to allow to operate on tailscaled without sudo") @@ -309,9 +309,15 @@ func prefsFromUpArgs(upArgs upArgsT, warnf logger.Logf, st *ipnstate.Status, goo var tags []string if upArgs.advertiseTags != "" { tags = strings.Split(upArgs.advertiseTags, ",") - for _, tag := range tags { - err := tailcfg.CheckTag(tag) - if err != nil { + for i, tag := range tags { + // Allow users to omit the "tag:" prefix; if the tag has no + // colon at all, add it for them. Tags with a colon must be + // fully qualified ("tag:foo") and are validated as-is. + if !strings.Contains(tag, ":") { + tag = "tag:" + tag + tags[i] = tag + } + if err := tailcfg.CheckTag(tag); err != nil { return nil, fmt.Errorf("tag: %q: %s", tag, err) } } @@ -334,8 +340,7 @@ func prefsFromUpArgs(upArgs upArgsT, warnf logger.Logf, st *ipnstate.Status, goo if expr, useAutoExitNode := ipn.ParseAutoExitNodeString(upArgs.exitNodeIP); useAutoExitNode { prefs.AutoExitNode = expr } else if err := prefs.SetExitNodeIP(upArgs.exitNodeIP, st); err != nil { - var e ipn.ExitNodeLocalIPError - if errors.As(err, &e) { + if _, ok := errors.AsType[ipn.ExitNodeLocalIPError](err); ok { return nil, fmt.Errorf("%w; did you mean --advertise-exit-node?", err) } return nil, err @@ -358,6 +363,11 @@ func prefsFromUpArgs(upArgs upArgsT, warnf logger.Logf, st *ipnstate.Status, goo if goos == "linux" { prefs.NoSNAT = !upArgs.snat + // We want to make sure user is aware setting --snat-subnet-routes=false with --advertise-exit-node would break exitnode, + // but we won't prevent them from doing it since there are current dependencies on that combination. (as of 2026-03-25) + if prefs.NoSNAT && prefs.AdvertisesExitNode() { + warnf("--snat-subnet-routes=false is set with --advertise-exit-node; internet traffic through this exit node may not work as expected") + } // Backfills for NoStatefulFiltering occur when loading a profile; just set it explicitly here. prefs.NoStatefulFiltering.Set(!upArgs.statefulFiltering) @@ -595,6 +605,11 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE } }() + if !buildfeatures.HasIPNBus { + fmt.Fprintln(Stderr, "binary built with ts_omit_ipnbus; not waiting for completion") + return nil + } + // Start watching the IPN bus before we call Start() or StartLoginInteractive(), // or we could miss IPN notifications. // @@ -717,7 +732,7 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE if s := n.State; s != nil { ipnIsRunning = *s == ipn.Running } - if n.NetMap != nil && n.NetMap.NodeKey != origNodeKey { + if n.SelfChange != nil && n.SelfChange.Key != origNodeKey { waitingForKeyChange = false } if ipnIsRunning && !waitingForKeyChange { @@ -912,7 +927,7 @@ func addPrefFlagMapping(flagName string, prefNames ...string) { prefType := reflect.TypeFor[ipn.Prefs]() for _, pref := range prefNames { t := prefType - for _, name := range strings.Split(pref, ".") { + for name := range strings.SplitSeq(pref, ".") { // Crash at runtime if there's a typo in the prefName. f, ok := t.FieldByName(name) if !ok { diff --git a/cmd/tailscale/cli/whois.go b/cmd/tailscale/cli/whois.go index b2ad74149635b..7cc8f2889f5b3 100644 --- a/cmd/tailscale/cli/whois.go +++ b/cmd/tailscale/cli/whois.go @@ -26,7 +26,7 @@ var whoisCmd = &ffcli.Command{ FlagSet: func() *flag.FlagSet { fs := newFlagSet("whois") fs.BoolVar(&whoIsArgs.json, "json", false, "output in JSON format") - fs.StringVar(&whoIsArgs.proto, "proto", "", `protocol; one of "tcp" or "udp"; empty mans both `) + fs.StringVar(&whoIsArgs.proto, "proto", "", `protocol; one of "tcp" or "udp"; empty means both`) return fs }(), } diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index b4605f9f2e926..d23ab1f4f658a 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -239,7 +239,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/tstime from tailscale.com/control/controlhttp+ tailscale.com/tstime/mono from tailscale.com/tstime/rate tailscale.com/tstime/rate from tailscale.com/cmd/tailscale/cli - tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb from tailscale.com/util/eventbus+ tailscale.com/tsweb/varz from tailscale.com/util/usermetric+ tailscale.com/types/appctype from tailscale.com/client/local+ tailscale.com/types/dnstype from tailscale.com/tailcfg+ @@ -253,12 +253,12 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/types/opt from tailscale.com/client/tailscale+ tailscale.com/types/persist from tailscale.com/ipn+ tailscale.com/types/preftype from tailscale.com/cmd/tailscale/cli+ - tailscale.com/types/ptr from tailscale.com/hostinfo+ tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/ipn+ tailscale.com/types/tkatype from tailscale.com/types/key+ tailscale.com/types/views from tailscale.com/tailcfg+ tailscale.com/util/backoff from tailscale.com/cmd/tailscale/cli + tailscale.com/util/bufiox from tailscale.com/types/key tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/net/netcheck+ tailscale.com/util/cloudenv from tailscale.com/net/dnscache+ @@ -331,7 +331,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpproxy+ golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ golang.org/x/net/internal/socks from golang.org/x/net/proxy golang.org/x/net/ipv4 from golang.org/x/net/icmp+ golang.org/x/net/ipv6 from golang.org/x/net/icmp+ diff --git a/cmd/tailscaled/depaware-min.txt b/cmd/tailscaled/depaware-min.txt index 2ad5cbca7b3af..8f0c34cf179f8 100644 --- a/cmd/tailscaled/depaware-min.txt +++ b/cmd/tailscaled/depaware-min.txt @@ -70,7 +70,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/ipn from tailscale.com/cmd/tailscaled+ tailscale.com/ipn/conffile from tailscale.com/cmd/tailscaled+ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnext+ - tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal+ tailscale.com/ipn/ipnlocal from tailscale.com/cmd/tailscaled+ tailscale.com/ipn/ipnlocal/netmapcache from tailscale.com/ipn/ipnlocal tailscale.com/ipn/ipnserver from tailscale.com/cmd/tailscaled @@ -101,7 +101,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/netknob from tailscale.com/logpolicy+ tailscale.com/net/netmon from tailscale.com/cmd/tailscaled+ tailscale.com/net/netns from tailscale.com/cmd/tailscaled+ - tailscale.com/net/netutil from tailscale.com/control/controlclient+ + tailscale.com/net/netutil from tailscale.com/control/controlhttp+ tailscale.com/net/netx from tailscale.com/control/controlclient+ tailscale.com/net/packet from tailscale.com/ipn/ipnlocal+ tailscale.com/net/packet/checksum from tailscale.com/net/tstun @@ -132,6 +132,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/dnstype from tailscale.com/client/tailscale/apitype+ tailscale.com/types/empty from tailscale.com/ipn+ + tailscale.com/types/events from tailscale.com/control/controlclient+ tailscale.com/types/flagtype from tailscale.com/cmd/tailscaled tailscale.com/types/ipproto from tailscale.com/ipn+ tailscale.com/types/key from tailscale.com/control/controlbase+ @@ -145,12 +146,12 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/types/opt from tailscale.com/control/controlknobs+ tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/preftype from tailscale.com/ipn+ - tailscale.com/types/ptr from tailscale.com/control/controlclient+ tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/control/controlclient+ tailscale.com/types/tkatype from tailscale.com/control/controlclient+ tailscale.com/types/views from tailscale.com/appc+ tailscale.com/util/backoff from tailscale.com/control/controlclient+ + tailscale.com/util/bufiox from tailscale.com/types/key tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/appc+ @@ -218,7 +219,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ golang.org/x/sync/errgroup from github.com/mdlayher/socket diff --git a/cmd/tailscaled/depaware-minbox.txt b/cmd/tailscaled/depaware-minbox.txt index 64911d9318f03..994310d60ab86 100644 --- a/cmd/tailscaled/depaware-minbox.txt +++ b/cmd/tailscaled/depaware-minbox.txt @@ -85,7 +85,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/ipn from tailscale.com/cmd/tailscaled+ tailscale.com/ipn/conffile from tailscale.com/cmd/tailscaled+ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnext+ - tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal+ tailscale.com/ipn/ipnlocal from tailscale.com/cmd/tailscaled+ tailscale.com/ipn/ipnlocal/netmapcache from tailscale.com/ipn/ipnlocal tailscale.com/ipn/ipnserver from tailscale.com/cmd/tailscaled @@ -118,7 +118,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/netknob from tailscale.com/logpolicy+ tailscale.com/net/netmon from tailscale.com/cmd/tailscaled+ tailscale.com/net/netns from tailscale.com/cmd/tailscaled+ - tailscale.com/net/netutil from tailscale.com/control/controlclient+ + tailscale.com/net/netutil from tailscale.com/client/local+ tailscale.com/net/netx from tailscale.com/control/controlclient+ tailscale.com/net/packet from tailscale.com/ipn/ipnlocal+ tailscale.com/net/packet/checksum from tailscale.com/net/tstun @@ -151,6 +151,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/dnstype from tailscale.com/client/tailscale/apitype+ tailscale.com/types/empty from tailscale.com/ipn+ + tailscale.com/types/events from tailscale.com/control/controlclient+ tailscale.com/types/flagtype from tailscale.com/cmd/tailscaled tailscale.com/types/ipproto from tailscale.com/ipn+ tailscale.com/types/key from tailscale.com/client/local+ @@ -164,12 +165,12 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/types/opt from tailscale.com/control/controlknobs+ tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/preftype from tailscale.com/ipn+ - tailscale.com/types/ptr from tailscale.com/control/controlclient+ tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/control/controlclient+ tailscale.com/types/tkatype from tailscale.com/control/controlclient+ tailscale.com/types/views from tailscale.com/appc+ tailscale.com/util/backoff from tailscale.com/control/controlclient+ + tailscale.com/util/bufiox from tailscale.com/types/key tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/appc+ @@ -239,7 +240,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts+ golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ golang.org/x/sync/errgroup from github.com/mdlayher/socket diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 207d86243b607..57332f5f8928f 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -6,7 +6,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de W đŸ’Ŗ github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W đŸ’Ŗ github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy - LD github.com/anmitsu/go-shlex from tailscale.com/tempfork/gliderlabs/ssh + LD github.com/anmitsu/go-shlex from github.com/tailscale/gliderssh L github.com/aws/aws-sdk-go-v2/aws from github.com/aws/aws-sdk-go-v2/aws/defaults+ L github.com/aws/aws-sdk-go-v2/aws/arn from tailscale.com/ipn/store/awsstore L github.com/aws/aws-sdk-go-v2/aws/defaults from github.com/aws/aws-sdk-go-v2/service/ssm+ @@ -130,7 +130,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/google/nftables/expr from github.com/google/nftables+ L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ L github.com/google/nftables/xt from github.com/google/nftables/expr+ - DW github.com/google/uuid from tailscale.com/clientupdate+ + W github.com/google/uuid from tailscale.com/clientupdate github.com/hdevalence/ed25519consensus from tailscale.com/clientupdate/distsign+ github.com/huin/goupnp from github.com/huin/goupnp/dcps/internetgateway2+ github.com/huin/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper @@ -173,9 +173,9 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal LD github.com/pkg/sftp from tailscale.com/ssh/tailssh LD github.com/pkg/sftp/internal/encoding/ssh/filexfer from github.com/pkg/sftp - D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack L đŸ’Ŗ github.com/safchain/ethtool from tailscale.com/net/netkernelconf+ - W đŸ’Ŗ github.com/tailscale/certstore from tailscale.com/control/controlclient + DW đŸ’Ŗ github.com/tailscale/certstore from tailscale.com/control/controlclient + LD github.com/tailscale/gliderssh from tailscale.com/ssh/tailssh W đŸ’Ŗ github.com/tailscale/go-winio from tailscale.com/safesocket W đŸ’Ŗ github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio W đŸ’Ŗ github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio @@ -250,7 +250,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de gvisor.dev/gvisor/pkg/tcpip/transport/udp from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ gvisor.dev/gvisor/pkg/waiter from gvisor.dev/gvisor/pkg/context+ tailscale.com from tailscale.com/version - tailscale.com/appc from tailscale.com/ipn/ipnlocal + tailscale.com/appc from tailscale.com/ipn/ipnlocal+ đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/ipn+ LD tailscale.com/chirp from tailscale.com/cmd/tailscaled tailscale.com/client/local from tailscale.com/client/web+ @@ -258,6 +258,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/client/web from tailscale.com/ipn/ipnlocal tailscale.com/clientupdate from tailscale.com/feature/clientupdate LW tailscale.com/clientupdate/distsign from tailscale.com/clientupdate + tailscale.com/cmd/tailscale/cli/jsonoutput from tailscale.com/feature/tailnetlock tailscale.com/cmd/tailscaled/childproc from tailscale.com/cmd/tailscaled+ tailscale.com/cmd/tailscaled/tailscaledhooks from tailscale.com/cmd/tailscaled+ tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ @@ -303,10 +304,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/feature/posture from tailscale.com/feature/condregister tailscale.com/feature/relayserver from tailscale.com/feature/condregister L tailscale.com/feature/sdnotify from tailscale.com/feature/condregister + LD tailscale.com/feature/ssh from tailscale.com/cmd/tailscaled tailscale.com/feature/syspolicy from tailscale.com/feature/condregister+ tailscale.com/feature/taildrop from tailscale.com/feature/condregister + tailscale.com/feature/tailnetlock from tailscale.com/feature/condregister L tailscale.com/feature/tap from tailscale.com/feature/condregister tailscale.com/feature/tpm from tailscale.com/feature/condregister + L đŸ’Ŗ tailscale.com/feature/tundevstats from tailscale.com/feature/condregister tailscale.com/feature/useproxy from tailscale.com/feature/condregister/useproxy tailscale.com/feature/wakeonlan from tailscale.com/feature/condregister tailscale.com/health from tailscale.com/control/controlclient+ @@ -315,7 +319,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/ipn from tailscale.com/client/local+ W tailscale.com/ipn/auditlog from tailscale.com/cmd/tailscaled tailscale.com/ipn/conffile from tailscale.com/cmd/tailscaled+ - W đŸ’Ŗ tailscale.com/ipn/desktop from tailscale.com/cmd/tailscaled + W đŸ’Ŗ tailscale.com/ipn/desktop from tailscale.com/feature/condregister đŸ’Ŗ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnlocal+ tailscale.com/ipn/ipnext from tailscale.com/ipn/auditlog+ tailscale.com/ipn/ipnlocal from tailscale.com/cmd/tailscaled+ @@ -361,7 +365,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/netutil from tailscale.com/client/local+ tailscale.com/net/netx from tailscale.com/control/controlclient+ tailscale.com/net/packet from tailscale.com/feature/capture+ - tailscale.com/net/packet/checksum from tailscale.com/net/tstun + tailscale.com/net/packet/checksum from tailscale.com/net/tstun+ tailscale.com/net/ping from tailscale.com/net/netcheck+ tailscale.com/net/portmapper from tailscale.com/feature/portmapper+ tailscale.com/net/portmapper/portmappertype from tailscale.com/feature/portmapper+ @@ -387,11 +391,10 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/proxymap from tailscale.com/tsd+ đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/local+ LD tailscale.com/sessionrecording from tailscale.com/ssh/tailssh - LD đŸ’Ŗ tailscale.com/ssh/tailssh from tailscale.com/cmd/tailscaled + LD đŸ’Ŗ tailscale.com/ssh/tailssh from tailscale.com/feature/ssh tailscale.com/syncs from tailscale.com/cmd/tailscaled+ tailscale.com/tailcfg from tailscale.com/client/local+ tailscale.com/tempfork/acme from tailscale.com/ipn/ipnlocal - LD tailscale.com/tempfork/gliderlabs/ssh from tailscale.com/ssh/tailssh tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock tailscale.com/tempfork/httprec from tailscale.com/feature/c2n tailscale.com/tka from tailscale.com/client/local+ @@ -400,12 +403,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/tstime from tailscale.com/control/controlclient+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ tailscale.com/tstime/rate from tailscale.com/wgengine/filter - tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb from tailscale.com/util/eventbus+ tailscale.com/tsweb/varz from tailscale.com/cmd/tailscaled+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/bools from tailscale.com/wgengine/netlog tailscale.com/types/dnstype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/empty from tailscale.com/ipn+ + tailscale.com/types/events from tailscale.com/control/controlclient+ tailscale.com/types/flagtype from tailscale.com/cmd/tailscaled tailscale.com/types/ipproto from tailscale.com/net/flowtrack+ tailscale.com/types/key from tailscale.com/client/local+ @@ -420,12 +424,12 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/types/opt from tailscale.com/control/controlknobs+ tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/preftype from tailscale.com/ipn+ - tailscale.com/types/ptr from tailscale.com/control/controlclient+ tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/control/controlclient+ tailscale.com/types/tkatype from tailscale.com/tka+ tailscale.com/types/views from tailscale.com/ipn/ipnlocal+ tailscale.com/util/backoff from tailscale.com/cmd/tailscaled+ + tailscale.com/util/bufiox from tailscale.com/types/key tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/control/controlclient+ @@ -522,13 +526,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de golang.org/x/net/dns/dnsmessage from tailscale.com/appc+ golang.org/x/net/http/httpguts from tailscale.com/ipn/ipnlocal golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy - golang.org/x/net/icmp from tailscale.com/net/ping+ + golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts+ golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ golang.org/x/net/internal/socks from golang.org/x/net/proxy - golang.org/x/net/ipv4 from github.com/prometheus-community/pro-bing+ - golang.org/x/net/ipv6 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ + golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ golang.org/x/net/proxy from tailscale.com/net/netns D golang.org/x/net/route from tailscale.com/net/netmon+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ @@ -589,22 +593,22 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de crypto/internal/boring/bbig from crypto/ecdsa+ crypto/internal/boring/sig from crypto/internal/boring crypto/internal/constanttime from crypto/internal/fips140/edwards25519+ - crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140 from crypto/fips140+ crypto/internal/fips140/aes from crypto/aes+ crypto/internal/fips140/aes/gcm from crypto/cipher+ crypto/internal/fips140/alias from crypto/cipher+ crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ - crypto/internal/fips140/check from crypto/internal/fips140/aes+ - crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/check from crypto/fips140+ + crypto/internal/fips140/drbg from crypto/hpke+ crypto/internal/fips140/ecdh from crypto/ecdh crypto/internal/fips140/ecdsa from crypto/ecdsa crypto/internal/fips140/ed25519 from crypto/ed25519 crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 crypto/internal/fips140/edwards25519/field from crypto/ecdh+ - crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hkdf from crypto/hkdf+ crypto/internal/fips140/hmac from crypto/hmac+ crypto/internal/fips140/mlkem from crypto/mlkem - crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec from crypto/ecdsa+ crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec crypto/internal/fips140/rsa from crypto/rsa crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ @@ -639,7 +643,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de crypto/x509 from crypto/tls+ D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ - DW database/sql/driver from github.com/google/uuid + W database/sql/driver from github.com/google/uuid W debug/dwarf from debug/pe W debug/pe from github.com/dblohm7/wingoes/pe embed from github.com/tailscale/web-client-prebuilt+ @@ -679,7 +683,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de internal/goos from crypto/x509+ internal/msan from internal/runtime/maps+ internal/nettrace from net+ - internal/oserror from io/fs+ + internal/oserror from internal/syscall/windows+ internal/poll from net+ internal/profile from net/http/pprof internal/profilerecord from runtime+ @@ -689,9 +693,9 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de internal/runtime/atomic from internal/runtime/exithook+ L internal/runtime/cgroup from runtime internal/runtime/exithook from runtime - internal/runtime/gc from runtime+ + internal/runtime/gc from internal/runtime/gc/scan+ internal/runtime/gc/scan from runtime - internal/runtime/maps from reflect+ + internal/runtime/maps from hash/maphash+ internal/runtime/math from internal/runtime/maps+ internal/runtime/pprof/label from runtime+ internal/runtime/sys from crypto/subtle+ @@ -705,7 +709,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de internal/synctest from sync internal/syscall/execenv from os+ LD internal/syscall/unix from crypto/internal/sysrand+ - W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/fips140deps/time+ W internal/syscall/windows/registry from mime+ W internal/syscall/windows/sysdll from internal/syscall/windows+ internal/testlog from os @@ -729,7 +733,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptrace from github.com/prometheus-community/pro-bing+ + net/http/httptrace from github.com/aws/smithy-go/transport/http+ net/http/httputil from github.com/aws/smithy-go/transport/http+ net/http/internal from net/http+ net/http/internal/ascii from net/http+ diff --git a/cmd/tailscaled/ssh.go b/cmd/tailscaled/ssh.go index e69cbd5dce086..8de3117944431 100644 --- a/cmd/tailscaled/ssh.go +++ b/cmd/tailscaled/ssh.go @@ -5,5 +5,5 @@ package main -// Force registration of tailssh with LocalBackend. -import _ "tailscale.com/ssh/tailssh" +// Register implementations of various SSH hooks. +import _ "tailscale.com/feature/ssh" diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index df0d68e077b2b..fe18731ae09a1 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -744,6 +744,7 @@ func tryEngine(logf logger.Logf, sys *tsd.System, name string) (onlyNetstack boo ListenPort: args.port, NetMon: sys.NetMon.Get(), HealthTracker: sys.HealthTracker.Get(), + ExtraRootCAs: sys.ExtraRootCAs, Metrics: sys.UserMetricsRegistry(), Dialer: sys.Dialer.Get(), SetSubsystem: sys.Set, diff --git a/cmd/tailscaled/tailscaled_test.go b/cmd/tailscaled/tailscaled_test.go index 7d76e7683a623..ab6482293687b 100644 --- a/cmd/tailscaled/tailscaled_test.go +++ b/cmd/tailscaled/tailscaled_test.go @@ -58,7 +58,7 @@ func TestStateStoreError(t *testing.T) { args.statedir = t.TempDir() args.tunname = "userspace-networking" - t.Run("new state", func(t *testing.T) { + t.Run("new-state", func(t *testing.T) { sys := tsd.NewSystem() sys.NetMon.Set(must.Get(netmon.New(sys.Bus.Get(), t.Logf))) lb, err := getLocalBackend(t.Context(), t.Logf, logID.Public(), sys) @@ -70,7 +70,7 @@ func TestStateStoreError(t *testing.T) { t.Errorf("StateStoreHealth is unhealthy on fresh LocalBackend:\n%s", strings.Join(lb.HealthTracker().Strings(), "\n")) } }) - t.Run("corrupt state", func(t *testing.T) { + t.Run("corrupt-state", func(t *testing.T) { sys := tsd.NewSystem() sys.NetMon.Set(must.Get(netmon.New(sys.Bus.Get(), t.Logf))) // Populate the state file with something that will fail to parse to diff --git a/cmd/tailscaled/tailscaled_windows.go b/cmd/tailscaled/tailscaled_windows.go index 63c8b30c99348..0ad550d4cc0cd 100644 --- a/cmd/tailscaled/tailscaled_windows.go +++ b/cmd/tailscaled/tailscaled_windows.go @@ -45,7 +45,6 @@ import ( "tailscale.com/drive/driveimpl" "tailscale.com/envknob" _ "tailscale.com/ipn/auditlog" - _ "tailscale.com/ipn/desktop" "tailscale.com/logpolicy" "tailscale.com/net/dns" "tailscale.com/net/netmon" diff --git a/cmd/testwrapper/testwrapper.go b/cmd/testwrapper/testwrapper.go index e35b83407bbb8..204409a630383 100644 --- a/cmd/testwrapper/testwrapper.go +++ b/cmd/testwrapper/testwrapper.go @@ -352,8 +352,7 @@ func main() { // If there's nothing to retry and no non-retryable tests have // failed then we've probably hit a build error. if err := <-runErr; len(toRetry) == 0 && err != nil { - var exit *exec.ExitError - if errors.As(err, &exit) { + if exit, ok := errors.AsType[*exec.ExitError](err); ok { if code := exit.ExitCode(); code > -1 { os.Exit(exit.ExitCode()) } diff --git a/cmd/testwrapper/testwrapper_test.go b/cmd/testwrapper/testwrapper_test.go index cf023f4367483..46400fd1c0a67 100644 --- a/cmd/testwrapper/testwrapper_test.go +++ b/cmd/testwrapper/testwrapper_test.go @@ -220,11 +220,14 @@ func TestCached(t *testing.T) { // Construct our trivial package. pkgDir := t.TempDir() + goVersion := runtime.Version() + goVersion = strings.TrimPrefix(goVersion, "go") + goVersion, _, _ = strings.Cut(goVersion, "-X:") // map 1.26.1-X:nogreenteagc to 1.26.1 + goMod := fmt.Sprintf(`module example.com go %s -`, runtime.Version()[2:]) // strip leading "go" - +`, goVersion) test := `package main import "testing" @@ -273,8 +276,7 @@ func TestCached(t *testing.T) {} } func errExitCode(err error) (int, bool) { - var exit *exec.ExitError - if errors.As(err, &exit) { + if exit, ok := errors.AsType[*exec.ExitError](err); ok { return exit.ExitCode(), true } return 0, false diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index 8a0177d1d66f7..f58e4201ab83c 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -110,6 +110,7 @@ func newIPN(jsConfig js.Value) map[string]any { SetSubsystem: sys.Set, ControlKnobs: sys.ControlKnobs(), HealthTracker: sys.HealthTracker.Get(), + ExtraRootCAs: sys.ExtraRootCAs, Metrics: sys.UserMetricsRegistry(), EventBus: sys.Bus.Get(), }) @@ -257,44 +258,50 @@ func (i *jsIPN) run(jsCallbacks js.Value) { if n.State != nil { notifyState(*n.State) } - if nm := n.NetMap; nm != nil { - jsNetMap := jsNetMap{ - Self: jsNetMapSelfNode{ - jsNetMapNode: jsNetMapNode{ - Name: nm.SelfName(), - Addresses: mapSliceView(nm.GetAddresses(), func(a netip.Prefix) string { return a.Addr().String() }), - NodeKey: nm.NodeKey.String(), - MachineKey: nm.MachineKey.String(), - }, - MachineStatus: jsMachineStatus[nm.GetMachineStatus()], - }, - Peers: mapSlice(nm.Peers, func(p tailcfg.NodeView) jsNetMapPeerNode { - name := p.Name() - if name == "" { - // In practice this should only happen for Hello. - name = p.Hostinfo().Hostname() - } - addrs := make([]string, p.Addresses().Len()) - for i, ap := range p.Addresses().All() { - addrs[i] = ap.Addr().String() - } - return jsNetMapPeerNode{ + if n.SelfChange != nil { + // Self changed: rebuild the JS-side NetMap snapshot. Peers + // don't ride on the bus anymore, so fetch them on demand + // from LocalBackend. + nm := i.lb.NetMapWithPeers() + if nm != nil { + jsNetMap := jsNetMap{ + Self: jsNetMapSelfNode{ jsNetMapNode: jsNetMapNode{ - Name: name, - Addresses: addrs, - MachineKey: p.Machine().String(), - NodeKey: p.Key().String(), + Name: nm.SelfName(), + Addresses: mapSliceView(nm.GetAddresses(), func(a netip.Prefix) string { return a.Addr().String() }), + NodeKey: nm.NodeKey.String(), + MachineKey: nm.MachineKey.String(), }, - Online: p.Online().Clone(), - TailscaleSSHEnabled: p.Hostinfo().TailscaleSSHEnabled(), - } - }), - LockedOut: nm.TKAEnabled && nm.SelfNode.KeySignature().Len() == 0, - } - if jsonNetMap, err := json.Marshal(jsNetMap); err == nil { - jsCallbacks.Call("notifyNetMap", string(jsonNetMap)) - } else { - log.Printf("Could not generate JSON netmap: %v", err) + MachineStatus: jsMachineStatus[nm.GetMachineStatus()], + }, + Peers: mapSlice(nm.Peers, func(p tailcfg.NodeView) jsNetMapPeerNode { + name := p.Name() + if name == "" { + // In practice this should only happen for Hello. + name = p.Hostinfo().Hostname() + } + addrs := make([]string, p.Addresses().Len()) + for i, ap := range p.Addresses().All() { + addrs[i] = ap.Addr().String() + } + return jsNetMapPeerNode{ + jsNetMapNode: jsNetMapNode{ + Name: name, + Addresses: addrs, + MachineKey: p.Machine().String(), + NodeKey: p.Key().String(), + }, + Online: p.Online().Clone(), + TailscaleSSHEnabled: p.Hostinfo().TailscaleSSHEnabled(), + } + }), + LockedOut: nm.TKAEnabled && nm.SelfNode.KeySignature().Len() == 0, + } + if jsonNetMap, err := json.Marshal(jsNetMap); err == nil { + jsCallbacks.Call("notifyNetMap", string(jsonNetMap)) + } else { + log.Printf("Could not generate JSON netmap: %v", err) + } } } if n.BrowseToURL != nil { diff --git a/cmd/tsidp/depaware.txt b/cmd/tsidp/depaware.txt index bb991383c8a06..cf1a4c279c865 100644 --- a/cmd/tsidp/depaware.txt +++ b/cmd/tsidp/depaware.txt @@ -6,77 +6,6 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar W đŸ’Ŗ github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W đŸ’Ŗ github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy - github.com/aws/aws-sdk-go-v2/aws from github.com/aws/aws-sdk-go-v2/aws/defaults+ - github.com/aws/aws-sdk-go-v2/aws/defaults from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/aws/middleware from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/aws-sdk-go-v2/aws/protocol/query from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/aws/protocol/restjson from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/aws/protocol/xml from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/aws/ratelimit from github.com/aws/aws-sdk-go-v2/aws/retry - github.com/aws/aws-sdk-go-v2/aws/retry from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client+ - github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 from github.com/aws/aws-sdk-go-v2/aws/signer/v4 - github.com/aws/aws-sdk-go-v2/aws/signer/v4 from github.com/aws/aws-sdk-go-v2/internal/auth/smithy+ - github.com/aws/aws-sdk-go-v2/aws/transport/http from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/config from tailscale.com/wif - github.com/aws/aws-sdk-go-v2/credentials from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/endpointcreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds - github.com/aws/aws-sdk-go-v2/credentials/processcreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/ssocreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/stscreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/feature/ec2/imds from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/feature/ec2/imds/internal/config from github.com/aws/aws-sdk-go-v2/feature/ec2/imds - github.com/aws/aws-sdk-go-v2/internal/auth from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ - github.com/aws/aws-sdk-go-v2/internal/auth/smithy from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/configsources from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/context from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/aws-sdk-go-v2/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 from github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints+ - github.com/aws/aws-sdk-go-v2/internal/ini from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/internal/middleware from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/rand from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/aws-sdk-go-v2/internal/sdk from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/aws-sdk-go-v2/internal/sdkio from github.com/aws/aws-sdk-go-v2/credentials/processcreds - github.com/aws/aws-sdk-go-v2/internal/shareddefaults from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/internal/strings from github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 - github.com/aws/aws-sdk-go-v2/internal/sync/singleflight from github.com/aws/aws-sdk-go-v2/aws - github.com/aws/aws-sdk-go-v2/internal/timeconv from github.com/aws/aws-sdk-go-v2/aws/retry - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/sso from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso - github.com/aws/aws-sdk-go-v2/service/sso/types from github.com/aws/aws-sdk-go-v2/service/sso - github.com/aws/aws-sdk-go-v2/service/ssooidc from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/ssooidc/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/aws-sdk-go-v2/service/ssooidc/types from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/aws-sdk-go-v2/service/sts from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/sts/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/sts/types from github.com/aws/aws-sdk-go-v2/credentials/stscreds+ - github.com/aws/smithy-go from github.com/aws/aws-sdk-go-v2/aws/protocol/restjson+ - github.com/aws/smithy-go/auth from github.com/aws/aws-sdk-go-v2/internal/auth+ - github.com/aws/smithy-go/auth/bearer from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/context from github.com/aws/smithy-go/auth/bearer - github.com/aws/smithy-go/document from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/encoding from github.com/aws/smithy-go/encoding/json+ - github.com/aws/smithy-go/encoding/httpbinding from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ - github.com/aws/smithy-go/encoding/json from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/smithy-go/encoding/xml from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/smithy-go/endpoints from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/endpoints/private/rulesfn from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/smithy-go/internal/sync/singleflight from github.com/aws/smithy-go/auth/bearer - github.com/aws/smithy-go/io from github.com/aws/aws-sdk-go-v2/feature/ec2/imds+ - github.com/aws/smithy-go/logging from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/smithy-go/middleware from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/private/requestcompression from github.com/aws/aws-sdk-go-v2/config - github.com/aws/smithy-go/ptr from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/rand from github.com/aws/aws-sdk-go-v2/aws/middleware - github.com/aws/smithy-go/time from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/tracing from github.com/aws/aws-sdk-go-v2/aws/middleware+ - github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http github.com/coder/websocket from tailscale.com/util/eventbus github.com/coder/websocket/internal/errd from github.com/coder/websocket github.com/coder/websocket/internal/util from github.com/coder/websocket @@ -105,7 +34,6 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar L đŸ’Ŗ github.com/godbus/dbus/v5 from tailscale.com/net/dns github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/transport/tcp - D github.com/google/uuid from github.com/prometheus-community/pro-bing github.com/hdevalence/ed25519consensus from tailscale.com/tka github.com/huin/goupnp from github.com/huin/goupnp/dcps/internetgateway2+ github.com/huin/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper @@ -128,9 +56,8 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar L đŸ’Ŗ github.com/mdlayher/socket from github.com/mdlayher/netlink+ đŸ’Ŗ github.com/mitchellh/go-ps from tailscale.com/safesocket github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal - D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack L đŸ’Ŗ github.com/safchain/ethtool from tailscale.com/net/netkernelconf - W đŸ’Ŗ github.com/tailscale/certstore from tailscale.com/control/controlclient + DW đŸ’Ŗ github.com/tailscale/certstore from tailscale.com/control/controlclient W đŸ’Ŗ github.com/tailscale/go-winio from tailscale.com/safesocket W đŸ’Ŗ github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio W đŸ’Ŗ github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio @@ -223,11 +150,9 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar tailscale.com/feature/buildfeatures from tailscale.com/wgengine/magicsock+ tailscale.com/feature/c2n from tailscale.com/tsnet tailscale.com/feature/condlite/expvar from tailscale.com/wgengine/magicsock - tailscale.com/feature/condregister/identityfederation from tailscale.com/tsnet tailscale.com/feature/condregister/oauthkey from tailscale.com/tsnet tailscale.com/feature/condregister/portmapper from tailscale.com/tsnet tailscale.com/feature/condregister/useproxy from tailscale.com/tsnet - tailscale.com/feature/identityfederation from tailscale.com/feature/condregister/identityfederation tailscale.com/feature/oauthkey from tailscale.com/feature/condregister/oauthkey tailscale.com/feature/portmapper from tailscale.com/feature/condregister/portmapper tailscale.com/feature/syspolicy from tailscale.com/logpolicy @@ -239,7 +164,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar tailscale.com/ipn from tailscale.com/client/local+ tailscale.com/ipn/conffile from tailscale.com/ipn/ipnlocal+ đŸ’Ŗ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnext+ - tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal+ tailscale.com/ipn/ipnlocal from tailscale.com/ipn/localapi+ tailscale.com/ipn/ipnlocal/netmapcache from tailscale.com/ipn/ipnlocal tailscale.com/ipn/ipnstate from tailscale.com/client/local+ @@ -309,12 +234,13 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar tailscale.com/tstime from tailscale.com/control/controlclient+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ tailscale.com/tstime/rate from tailscale.com/wgengine/filter - tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb from tailscale.com/util/eventbus+ tailscale.com/tsweb/varz from tailscale.com/tsweb+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/bools from tailscale.com/tsnet+ tailscale.com/types/dnstype from tailscale.com/client/local+ tailscale.com/types/empty from tailscale.com/ipn+ + tailscale.com/types/events from tailscale.com/control/controlclient+ tailscale.com/types/ipproto from tailscale.com/ipn+ tailscale.com/types/key from tailscale.com/client/local+ tailscale.com/types/lazy from tailscale.com/cmd/tsidp+ @@ -328,12 +254,12 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar tailscale.com/types/opt from tailscale.com/cmd/tsidp+ tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/preftype from tailscale.com/ipn+ - tailscale.com/types/ptr from tailscale.com/control/controlclient+ tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/control/controlclient+ tailscale.com/types/tkatype from tailscale.com/client/local+ tailscale.com/types/views from tailscale.com/appc+ tailscale.com/util/backoff from tailscale.com/control/controlclient+ + tailscale.com/util/bufiox from tailscale.com/types/key tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/appc+ @@ -398,12 +324,10 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal đŸ’Ŗ tailscale.com/wgengine/wgint from tailscale.com/wgengine+ tailscale.com/wgengine/wglog from tailscale.com/wgengine - tailscale.com/wif from tailscale.com/feature/identityfederation golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ - LD golang.org/x/crypto/blowfish from golang.org/x/crypto/ssh/internal/bcrypt_pbkdf - golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305+ + golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305 golang.org/x/crypto/chacha20poly1305 from github.com/tailscale/wireguard-go/device+ golang.org/x/crypto/curve25519 from github.com/tailscale/wireguard-go/device+ golang.org/x/crypto/ed25519 from gopkg.in/square/go-jose.v2 @@ -415,24 +339,22 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar golang.org/x/crypto/pbkdf2 from gopkg.in/square/go-jose.v2 golang.org/x/crypto/poly1305 from github.com/tailscale/wireguard-go/device golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - LD golang.org/x/crypto/ssh from tailscale.com/ipn/ipnlocal - LD golang.org/x/crypto/ssh/internal/bcrypt_pbkdf from golang.org/x/crypto/ssh golang.org/x/exp/constraints from tailscale.com/tsweb/varz+ golang.org/x/exp/maps from tailscale.com/ipn/store/mem+ golang.org/x/net/bpf from github.com/mdlayher/netlink+ golang.org/x/net/dns/dnsmessage from tailscale.com/appc+ golang.org/x/net/http/httpguts from tailscale.com/ipn/ipnlocal golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy - golang.org/x/net/icmp from github.com/prometheus-community/pro-bing+ + golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts+ golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ golang.org/x/net/internal/socks from golang.org/x/net/proxy - golang.org/x/net/ipv4 from github.com/prometheus-community/pro-bing+ - golang.org/x/net/ipv6 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ + golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ golang.org/x/net/proxy from tailscale.com/net/netns D golang.org/x/net/route from tailscale.com/net/netmon+ - golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials+ + golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials golang.org/x/oauth2/clientcredentials from tailscale.com/feature/oauthkey golang.org/x/oauth2/internal from golang.org/x/oauth2+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ @@ -477,7 +399,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar crypto/aes from crypto/tls+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ - crypto/dsa from crypto/x509+ + crypto/dsa from crypto/x509 crypto/ecdh from crypto/ecdsa+ crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ @@ -526,21 +448,20 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar crypto/internal/randutil from crypto/internal/rand crypto/internal/sysrand from crypto/internal/fips140/drbg crypto/md5 from crypto/tls+ - crypto/mlkem from golang.org/x/crypto/ssh+ + crypto/mlkem from crypto/hpke+ crypto/rand from crypto/ed25519+ - crypto/rc4 from crypto/tls+ + crypto/rc4 from crypto/tls crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ crypto/sha3 from crypto/internal/fips140hash+ crypto/sha512 from crypto/ecdsa+ crypto/subtle from crypto/cipher+ - crypto/tls from github.com/prometheus-community/pro-bing+ + crypto/tls from net/http+ crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ - D database/sql/driver from github.com/google/uuid W debug/dwarf from debug/pe W debug/pe from github.com/dblohm7/wingoes/pe embed from github.com/tailscale/web-client-prebuilt+ @@ -629,7 +550,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptrace from github.com/prometheus-community/pro-bing+ + net/http/httptrace from net/http+ net/http/httputil from tailscale.com/client/web+ net/http/internal from net/http+ net/http/internal/ascii from net/http+ @@ -644,7 +565,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar os/user from github.com/godbus/dbus/v5+ path from debug/dwarf+ path/filepath from crypto/x509+ - reflect from database/sql/driver+ + reflect from encoding/asn1+ regexp from github.com/huin/goupnp/httpu+ regexp/syntax from regexp runtime from crypto/internal/fips140+ diff --git a/cmd/tsidp/tsidp_test.go b/cmd/tsidp/tsidp_test.go index 26c906fab216b..baf5aaee14fa8 100644 --- a/cmd/tsidp/tsidp_test.go +++ b/cmd/tsidp/tsidp_test.go @@ -137,14 +137,14 @@ func TestFlattenExtraClaims(t *testing.T) { expected map[string]any }{ { - name: "empty extra claims", + name: "empty-extra-claims", input: []capRule{ {ExtraClaims: map[string]any{}}, }, expected: map[string]any{}, }, { - name: "string and number values", + name: "string-and-number-values", input: []capRule{ { ExtraClaims: map[string]any{ @@ -159,7 +159,7 @@ func TestFlattenExtraClaims(t *testing.T) { }, }, { - name: "slice of strings and ints", + name: "slice-of-strings-and-ints", input: []capRule{ { ExtraClaims: map[string]any{ @@ -172,7 +172,8 @@ func TestFlattenExtraClaims(t *testing.T) { }, }, { - name: "duplicate values deduplicated (slice input)", + // duplicate values deduplicated via slice input + name: "dedup-slice-input", input: []capRule{ { ExtraClaims: map[string]any{ @@ -190,7 +191,8 @@ func TestFlattenExtraClaims(t *testing.T) { }, }, { - name: "ignore unsupported map type, keep valid scalar", + // ignore unsupported map type, keep valid scalar + name: "ignore-unsupported-map-keep-scalar", input: []capRule{ { ExtraClaims: map[string]any{ @@ -204,7 +206,7 @@ func TestFlattenExtraClaims(t *testing.T) { }, }, { - name: "scalar first, slice second", + name: "scalar-first-slice-second", input: []capRule{ {ExtraClaims: map[string]any{"foo": "bar"}}, {ExtraClaims: map[string]any{"foo": []any{"baz"}}}, @@ -214,7 +216,7 @@ func TestFlattenExtraClaims(t *testing.T) { }, }, { - name: "conflicting scalar and unsupported map", + name: "conflicting-scalar-and-unsupported-map", input: []capRule{ {ExtraClaims: map[string]any{"foo": "bar"}}, {ExtraClaims: map[string]any{"foo": map[string]any{"bad": "entry"}}}, @@ -224,7 +226,7 @@ func TestFlattenExtraClaims(t *testing.T) { }, }, { - name: "multiple slices with overlap", + name: "multiple-slices-with-overlap", input: []capRule{ {ExtraClaims: map[string]any{"roles": []any{"admin", "user"}}}, {ExtraClaims: map[string]any{"roles": []any{"admin", "guest"}}}, @@ -234,7 +236,7 @@ func TestFlattenExtraClaims(t *testing.T) { }, }, { - name: "slice with unsupported values", + name: "slice-with-unsupported-values", input: []capRule{ {ExtraClaims: map[string]any{ "mixed": []any{"ok", 42, map[string]string{"oops": "fail"}}, @@ -245,7 +247,7 @@ func TestFlattenExtraClaims(t *testing.T) { }, }, { - name: "duplicate scalar value", + name: "duplicate-scalar-value", input: []capRule{ {ExtraClaims: map[string]any{"env": "prod"}}, {ExtraClaims: map[string]any{"env": "prod"}}, @@ -279,7 +281,7 @@ func TestExtraClaims(t *testing.T) { expectError bool }{ { - name: "extra claim", + name: "extra-claim", claim: tailscaleClaims{ Claims: jwt.Claims{}, Nonce: "foobar", @@ -312,7 +314,7 @@ func TestExtraClaims(t *testing.T) { }, }, { - name: "duplicate claim distinct values", + name: "duplicate-claim-distinct-values", claim: tailscaleClaims{ Claims: jwt.Claims{}, Nonce: "foobar", @@ -350,7 +352,7 @@ func TestExtraClaims(t *testing.T) { }, }, { - name: "multiple extra claims", + name: "multiple-extra-claims", claim: tailscaleClaims{ Claims: jwt.Claims{}, Nonce: "foobar", @@ -389,7 +391,7 @@ func TestExtraClaims(t *testing.T) { }, }, { - name: "overwrite claim", + name: "overwrite-claim", claim: tailscaleClaims{ Claims: jwt.Claims{}, Nonce: "foobar", @@ -422,7 +424,7 @@ func TestExtraClaims(t *testing.T) { expectError: true, }, { - name: "empty extra claims", + name: "empty-extra-claims", claim: tailscaleClaims{ Claims: jwt.Claims{}, Nonce: "foobar", @@ -496,21 +498,21 @@ func TestServeToken(t *testing.T) { expected map[string]any }{ { - name: "GET not allowed", + name: "GET-not-allowed", method: "GET", grantType: "authorization_code", strictMode: false, expectError: true, }, { - name: "unsupported grant type", + name: "unsupported-grant-type", method: "POST", grantType: "pkcs", strictMode: false, expectError: true, }, { - name: "invalid code", + name: "invalid-code", method: "POST", grantType: "authorization_code", code: "invalid-code", @@ -518,7 +520,7 @@ func TestServeToken(t *testing.T) { expectError: true, }, { - name: "omit code from form", + name: "omit-code-from-form", method: "POST", grantType: "authorization_code", omitCode: true, @@ -526,7 +528,7 @@ func TestServeToken(t *testing.T) { expectError: true, }, { - name: "invalid redirect uri", + name: "invalid-redirect-uri", method: "POST", grantType: "authorization_code", code: "valid-code", @@ -536,7 +538,7 @@ func TestServeToken(t *testing.T) { expectError: true, }, { - name: "invalid remoteAddr", + name: "invalid-remoteAddr", method: "POST", grantType: "authorization_code", redirectURI: "https://rp.example.com/callback", @@ -546,7 +548,7 @@ func TestServeToken(t *testing.T) { expectError: true, }, { - name: "extra claim included (non-strict)", + name: "extra-claim-included-non-strict", method: "POST", grantType: "authorization_code", redirectURI: "https://rp.example.com/callback", @@ -568,7 +570,8 @@ func TestServeToken(t *testing.T) { }, }, { - name: "attempt to overwrite protected claim (non-strict)", + // attempt to overwrite protected claim in non-strict mode + name: "overwrite-protected-claim-non-strict", method: "POST", grantType: "authorization_code", redirectURI: "https://rp.example.com/callback", @@ -708,7 +711,7 @@ func TestExtraUserInfo(t *testing.T) { expectError bool }{ { - name: "extra claim", + name: "extra-claim", tokenValidTill: time.Now().Add(1 * time.Minute), caps: tailcfg.PeerCapMap{ tailcfg.PeerCapabilityTsIDP: { @@ -725,7 +728,7 @@ func TestExtraUserInfo(t *testing.T) { }, }, { - name: "duplicate claim distinct values", + name: "duplicate-claim-distinct-values", tokenValidTill: time.Now().Add(1 * time.Minute), caps: tailcfg.PeerCapMap{ tailcfg.PeerCapabilityTsIDP: { @@ -742,7 +745,7 @@ func TestExtraUserInfo(t *testing.T) { }, }, { - name: "multiple extra claims", + name: "multiple-extra-claims", tokenValidTill: time.Now().Add(1 * time.Minute), caps: tailcfg.PeerCapMap{ tailcfg.PeerCapabilityTsIDP: { @@ -761,13 +764,13 @@ func TestExtraUserInfo(t *testing.T) { }, }, { - name: "empty extra claims", + name: "empty-extra-claims", caps: tailcfg.PeerCapMap{}, tokenValidTill: time.Now().Add(1 * time.Minute), expected: map[string]any{}, }, { - name: "attempt to overwrite protected claim", + name: "overwrite-protected-claim", tokenValidTill: time.Now().Add(1 * time.Minute), caps: tailcfg.PeerCapMap{ tailcfg.PeerCapabilityTsIDP: { @@ -783,7 +786,7 @@ func TestExtraUserInfo(t *testing.T) { expectError: true, }, { - name: "extra claim omitted", + name: "extra-claim-omitted", tokenValidTill: time.Now().Add(1 * time.Minute), caps: tailcfg.PeerCapMap{ tailcfg.PeerCapabilityTsIDP: { @@ -798,7 +801,7 @@ func TestExtraUserInfo(t *testing.T) { expected: map[string]any{}, }, { - name: "expired token", + name: "expired-token", caps: tailcfg.PeerCapMap{}, tokenValidTill: time.Now().Add(-1 * time.Minute), expected: map[string]any{}, @@ -1131,19 +1134,22 @@ func TestGetAllowInsecureRegistration(t *testing.T) { expectAllowInsecureRegistration bool }{ { - name: "flag explicitly set to false - insecure registration disabled (strict mode)", + // flag explicitly set to false - insecure registration disabled (strict mode) + name: "flag-false-insecure-disabled", flagSet: true, flagValue: false, expectAllowInsecureRegistration: false, }, { - name: "flag explicitly set to true - insecure registration enabled", + // flag explicitly set to true - insecure registration enabled + name: "flag-true-insecure-enabled", flagSet: true, flagValue: true, expectAllowInsecureRegistration: true, }, { - name: "flag unset - insecure registration enabled (default for backward compatibility)", + // flag unset - insecure registration enabled (default for backward compatibility) + name: "flag-unset-insecure-enabled-default", flagSet: false, flagValue: false, // not used when unset expectAllowInsecureRegistration: true, @@ -1192,7 +1198,7 @@ func TestMigrateOAuthClients(t *testing.T) { expectOldRenamed bool }{ { - name: "migrate from old file to new file", + name: "migrate-old-to-new", setupOldFile: true, oldFileContent: map[string]*funnelClient{ "old-client": { @@ -1206,7 +1212,7 @@ func TestMigrateOAuthClients(t *testing.T) { expectOldRenamed: true, }, { - name: "new file already exists - no migration", + name: "new-file-exists-no-migration", setupNewFile: true, newFileContent: map[string]*funnelClient{ "existing-client": { @@ -1220,12 +1226,12 @@ func TestMigrateOAuthClients(t *testing.T) { expectOldRenamed: false, }, { - name: "neither file exists - create empty new file", + name: "neither-exists-create-empty", expectNewFileExists: true, expectOldRenamed: false, }, { - name: "both files exist - prefer new file", + name: "both-exist-prefer-new", setupOldFile: true, setupNewFile: true, oldFileContent: map[string]*funnelClient{ @@ -1373,19 +1379,19 @@ func TestGetConfigFilePath(t *testing.T) { expectError bool }{ { - name: "file exists in current directory - use current directory", + name: "file-in-cwd-use-cwd", fileName: "test-config.json", createInCwd: true, expectInCwd: true, }, { - name: "file does not exist - use root path", + name: "file-missing-use-root", fileName: "test-config.json", createInCwd: false, expectInCwd: false, }, { - name: "file exists in both - prefer current directory", + name: "file-in-both-prefer-cwd", fileName: "test-config.json", createInCwd: true, createInRoot: true, @@ -1472,7 +1478,7 @@ func TestAuthorizeStrictMode(t *testing.T) { }{ // Security boundary test: funnel rejection { - name: "funnel requests are always rejected for security", + name: "funnel-rejected", strictMode: true, clientID: "test-client", redirectURI: "https://rp.example.com/callback", @@ -1487,7 +1493,7 @@ func TestAuthorizeStrictMode(t *testing.T) { // Strict mode parameter validation tests (non-funnel) { - name: "strict mode - missing client_id", + name: "strict-missing-client_id", strictMode: true, clientID: "", redirectURI: "https://rp.example.com/callback", @@ -1496,7 +1502,7 @@ func TestAuthorizeStrictMode(t *testing.T) { expectCode: http.StatusBadRequest, }, { - name: "strict mode - missing redirect_uri", + name: "strict-missing-redirect_uri", strictMode: true, clientID: "test-client", redirectURI: "", @@ -1507,7 +1513,7 @@ func TestAuthorizeStrictMode(t *testing.T) { // Strict mode client validation tests (non-funnel) { - name: "strict mode - invalid client_id", + name: "strict-invalid-client_id", strictMode: true, clientID: "invalid-client", redirectURI: "https://rp.example.com/callback", @@ -1517,7 +1523,7 @@ func TestAuthorizeStrictMode(t *testing.T) { expectCode: http.StatusBadRequest, }, { - name: "strict mode - redirect_uri mismatch", + name: "strict-redirect_uri-mismatch", strictMode: true, clientID: "test-client", redirectURI: "https://wrong.example.com/callback", @@ -1666,7 +1672,7 @@ func TestServeTokenWithClientValidation(t *testing.T) { expectIDToken bool }{ { - name: "strict mode - valid token exchange with form credentials", + name: "strict-token-exchange-form-creds", strictMode: true, method: "POST", grantType: "authorization_code", @@ -1680,7 +1686,7 @@ func TestServeTokenWithClientValidation(t *testing.T) { expectIDToken: true, }, { - name: "strict mode - valid token exchange with basic auth", + name: "strict-token-exchange-basic-auth", strictMode: true, method: "POST", grantType: "authorization_code", @@ -1695,7 +1701,7 @@ func TestServeTokenWithClientValidation(t *testing.T) { expectIDToken: true, }, { - name: "strict mode - missing client credentials", + name: "strict-missing-client-creds", strictMode: true, method: "POST", grantType: "authorization_code", @@ -1708,7 +1714,7 @@ func TestServeTokenWithClientValidation(t *testing.T) { expectCode: http.StatusUnauthorized, }, { - name: "strict mode - client_id mismatch", + name: "strict-client_id-mismatch", strictMode: true, method: "POST", grantType: "authorization_code", @@ -1722,7 +1728,7 @@ func TestServeTokenWithClientValidation(t *testing.T) { expectCode: http.StatusBadRequest, }, { - name: "strict mode - invalid client secret", + name: "strict-invalid-client-secret", strictMode: true, method: "POST", grantType: "authorization_code", @@ -1737,7 +1743,7 @@ func TestServeTokenWithClientValidation(t *testing.T) { expectCode: http.StatusUnauthorized, }, { - name: "strict mode - redirect_uri mismatch", + name: "strict-redirect_uri-mismatch", strictMode: true, method: "POST", grantType: "authorization_code", @@ -1752,7 +1758,7 @@ func TestServeTokenWithClientValidation(t *testing.T) { expectCode: http.StatusBadRequest, }, { - name: "non-strict mode - no client validation required", + name: "non-strict-no-client-validation", strictMode: false, method: "POST", grantType: "authorization_code", @@ -1913,7 +1919,7 @@ func TestServeUserInfoWithClientValidation(t *testing.T) { expectUserInfo bool }{ { - name: "strict mode - valid token with existing client", + name: "strict-valid-token-existing-client", strictMode: true, setupToken: true, setupClient: true, @@ -1923,7 +1929,8 @@ func TestServeUserInfoWithClientValidation(t *testing.T) { expectUserInfo: true, }, { - name: "strict mode - valid token but client no longer exists", + // valid token but client no longer exists + name: "strict-token-client-deleted", strictMode: true, setupToken: true, setupClient: false, @@ -1934,7 +1941,7 @@ func TestServeUserInfoWithClientValidation(t *testing.T) { expectCode: http.StatusUnauthorized, }, { - name: "strict mode - expired token", + name: "strict-expired-token", strictMode: true, setupToken: true, setupClient: true, @@ -1945,7 +1952,7 @@ func TestServeUserInfoWithClientValidation(t *testing.T) { expectCode: http.StatusBadRequest, }, { - name: "strict mode - invalid token", + name: "strict-invalid-token", strictMode: true, setupToken: false, token: "invalid-token", @@ -1953,7 +1960,7 @@ func TestServeUserInfoWithClientValidation(t *testing.T) { expectCode: http.StatusBadRequest, }, { - name: "strict mode - token without client association", + name: "strict-token-no-client-assoc", strictMode: true, setupToken: true, setupClient: false, @@ -1964,7 +1971,7 @@ func TestServeUserInfoWithClientValidation(t *testing.T) { expectCode: http.StatusBadRequest, }, { - name: "non-strict mode - no client validation required", + name: "non-strict-no-client-validation", strictMode: false, setupToken: true, setupClient: false, diff --git a/cmd/tsnet-proxy/tsnet-proxy.go b/cmd/tsnet-proxy/tsnet-proxy.go new file mode 100644 index 0000000000000..0a83fd1a8dac8 --- /dev/null +++ b/cmd/tsnet-proxy/tsnet-proxy.go @@ -0,0 +1,173 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// The tsnet-proxy command exposes a local port on the tailnet under a +// chosen hostname. By default it proxies raw TCP; pass --http to reverse +// proxy as HTTP, or --https to reverse proxy as HTTPS with an auto-issued +// Tailscale cert. Both HTTP modes inject Tailscale-User-* identity headers +// from WhoIs. +// +// Arguments are [tailnet]: local is the port on localhost +// to proxy to and tailnet is the port to expose on the tailnet. If tailnet +// is omitted, it defaults to 443 for --https, 80 for --http, and the local +// port otherwise. +// +// go run ./cmd/tsnet-proxy myapp 8080 # raw TCP, tailnet :8080 +// go run ./cmd/tsnet-proxy myapp 22 2222 # raw TCP, tailnet :2222 +// go run ./cmd/tsnet-proxy --http myapp 8080 # tailnet :80 +// go run ./cmd/tsnet-proxy --https myapp 8080 # tailnet :443 +// +// Or run directly from the module, no checkout required: +// +// go run tailscale.com/cmd/tsnet-proxy@latest myapp 8080 +package main + +import ( + "flag" + "fmt" + "io" + "log" + "mime" + "net" + "net/http" + "net/http/httputil" + "net/url" + "os" + "strconv" + "unicode/utf8" + + "tailscale.com/client/local" + "tailscale.com/tsnet" +) + +func main() { + asHTTP := flag.Bool("http", false, "reverse proxy as HTTP and inject Tailscale-User-* headers") + asHTTPS := flag.Bool("https", false, "reverse proxy as HTTPS with an auto-issued Tailscale cert; implies --http") + dir := flag.String("dir", "", "directory to persist tsnet state (default: per-user config dir)") + verbose := flag.Bool("v", false, "verbose tsnet backend logs") + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), "usage: %s [flags] [tailnet]\n", flag.CommandLine.Name()) + flag.PrintDefaults() + } + flag.Parse() + + if n := flag.NArg(); n != 2 && n != 3 { + flag.Usage() + os.Exit(2) + } + name := flag.Arg(0) + localPort, err := parsePort(flag.Arg(1)) + if err != nil { + log.Fatalf("invalid local port %q: %v", flag.Arg(1), err) + } + tailnetPort := defaultTailnetPort(localPort, *asHTTP, *asHTTPS) + if flag.NArg() == 3 { + tailnetPort, err = parsePort(flag.Arg(2)) + if err != nil { + log.Fatalf("invalid tailnet port %q: %v", flag.Arg(2), err) + } + } + + target := "localhost:" + strconv.Itoa(localPort) + addr := ":" + strconv.Itoa(tailnetPort) + + s := &tsnet.Server{Hostname: name, Dir: *dir} + if *verbose { + s.Logf = log.Printf + } + defer s.Close() + + var ln net.Listener + if *asHTTPS { + ln, err = s.ListenTLS("tcp", addr) + } else { + ln, err = s.Listen("tcp", addr) + } + if err != nil { + log.Fatal(err) + } + defer ln.Close() + + log.Printf("proxying %s -> %s on tailnet", target, name+addr) + + if *asHTTP || *asHTTPS { + lc, err := s.LocalClient() + if err != nil { + log.Fatal(err) + } + targetURL := &url.URL{Scheme: "http", Host: target} + rp := &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(targetURL) + r.SetXForwarded() + addTailscaleIdentityHeaders(lc, r) + }, + } + log.Fatal(http.Serve(ln, rp)) + } + + for { + c, err := ln.Accept() + if err != nil { + log.Fatal(err) + } + go proxyTCP(c, target) + } +} + +func parsePort(s string) (int, error) { + p, err := strconv.Atoi(s) + if err != nil || p <= 0 || p > 65535 { + return 0, fmt.Errorf("bad port") + } + return p, nil +} + +// defaultTailnetPort returns the tailnet port when the user didn't +// specify one: 443 for HTTPS, 80 for HTTP, else the local port. +func defaultTailnetPort(local int, asHTTP, asHTTPS bool) int { + switch { + case asHTTPS: + return 443 + case asHTTP: + return 80 + } + return local +} + +func proxyTCP(c net.Conn, target string) { + defer c.Close() + d, err := net.Dial("tcp", target) + if err != nil { + log.Printf("dial %s: %v", target, err) + return + } + defer d.Close() + go io.Copy(d, c) + io.Copy(c, d) +} + +func addTailscaleIdentityHeaders(lc *local.Client, r *httputil.ProxyRequest) { + r.Out.Header.Del("Tailscale-User-Login") + r.Out.Header.Del("Tailscale-User-Name") + r.Out.Header.Del("Tailscale-User-Profile-Pic") + r.Out.Header.Del("Tailscale-Funnel-Request") + r.Out.Header.Del("Tailscale-Headers-Info") + + who, err := lc.WhoIs(r.In.Context(), r.In.RemoteAddr) + if err != nil || who == nil || who.Node.IsTagged() { + return + } + r.Out.Header.Set("Tailscale-User-Login", encHeader(who.UserProfile.LoginName)) + r.Out.Header.Set("Tailscale-User-Name", encHeader(who.UserProfile.DisplayName)) + r.Out.Header.Set("Tailscale-User-Profile-Pic", who.UserProfile.ProfilePicURL) +} + +// encHeader mirrors the encoding tailscaled's serve path applies to +// user-provided strings destined for HTTP headers. +func encHeader(v string) string { + if !utf8.ValidString(v) { + return "" + } + return mime.QEncoding.Encode("utf-8", v) +} diff --git a/cmd/tsp/tsp.go b/cmd/tsp/tsp.go new file mode 100644 index 0000000000000..a59b352d5b060 --- /dev/null +++ b/cmd/tsp/tsp.go @@ -0,0 +1,513 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Program tsp is a low-level Tailscale protocol tool for performing +// composable building block operations like generating keys and +// registering nodes. +package main + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "os" + "reflect" + "strings" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/control/tsp" + "tailscale.com/hostinfo" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +var globalArgs struct { + // serverURL is the base URL of the coordination server (-s flag). + // If empty, tsp.DefaultServerURL is used. + serverURL string + + // controlKeyFile is a path to a file containing the server's + // MachinePublic key in MarshalText form (--control-key flag). + // When set, server key discovery is skipped. + controlKeyFile string +} + +func main() { + args := os.Args[1:] + if err := rootCmd.Parse(args); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + err := rootCmd.Run(context.Background()) + if errors.Is(err, flag.ErrHelp) { + os.Exit(0) + } + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } +} + +var rootCmd = &ffcli.Command{ + Name: "tsp", + ShortUsage: "tsp [-s url] [flags]", + ShortHelp: "Low-level Tailscale protocol tool.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("tsp", flag.ExitOnError) + fs.StringVar(&globalArgs.serverURL, "s", "", "base URL of coordination server (default: "+tsp.DefaultServerURL+")") + fs.StringVar(&globalArgs.controlKeyFile, "control-key", "", "file containing the server's public key (skips discovery)") + return fs + })(), + Subcommands: []*ffcli.Command{ + newMachineKeyCmd, + newNodeKeyCmd, + newNodeCmd, + registerCmd, + mapCmd, + discoverServerKeyCmd, + }, + Exec: func(ctx context.Context, args []string) error { + return flag.ErrHelp + }, +} + +var newMachineKeyArgs struct { + output string +} + +var newMachineKeyCmd = &ffcli.Command{ + Name: "new-machine-key", + ShortUsage: "tsp new-machine-key [-o file]", + ShortHelp: "Generate a new machine key.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("new-machine-key", flag.ExitOnError) + fs.StringVar(&newMachineKeyArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runNewMachineKey, +} + +func runNewMachineKey(ctx context.Context, args []string) error { + k := key.NewMachine() + text, err := k.MarshalText() + if err != nil { + return err + } + text = append(text, '\n') + return writeOutput(newMachineKeyArgs.output, text) +} + +var newNodeKeyArgs struct { + output string +} + +var newNodeKeyCmd = &ffcli.Command{ + Name: "new-node-key", + ShortUsage: "tsp new-node-key [-o file]", + ShortHelp: "Generate a new node key.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("new-node-key", flag.ExitOnError) + fs.StringVar(&newNodeKeyArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runNewNodeKey, +} + +func runNewNodeKey(ctx context.Context, args []string) error { + k := key.NewNode() + text, err := k.MarshalText() + if err != nil { + return err + } + text = append(text, '\n') + return writeOutput(newNodeKeyArgs.output, text) +} + +var discoverServerKeyArgs struct { + output string +} + +var discoverServerKeyCmd = &ffcli.Command{ + Name: "discover-server-key", + ShortUsage: "tsp [-s url] discover-server-key [-o file]", + ShortHelp: "Discover and print the coordination server's public key.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("discover-server-key", flag.ExitOnError) + fs.StringVar(&discoverServerKeyArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runDiscoverServerKey, +} + +func runDiscoverServerKey(ctx context.Context, args []string) error { + k, err := tsp.DiscoverServerKey(ctx, globalArgs.serverURL) + if err != nil { + return err + } + text, err := k.MarshalText() + if err != nil { + return fmt.Errorf("marshaling server key: %w", err) + } + text = append(text, '\n') + return writeOutput(discoverServerKeyArgs.output, text) +} + +var newNodeArgs struct { + nodeKeyFile string + machineKeyFile string + output string +} + +var newNodeCmd = &ffcli.Command{ + Name: "new-node", + ShortUsage: "tsp [-s url] [--control-key file] new-node [-n node-key-file] [-m machine-key-file] [-o output]", + ShortHelp: "Generate a new node JSON file with keys and server info.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("new-node", flag.ExitOnError) + fs.StringVar(&newNodeArgs.nodeKeyFile, "n", "", "existing node key file (default: generate new)") + fs.StringVar(&newNodeArgs.machineKeyFile, "m", "", "existing machine key file (default: generate new)") + fs.StringVar(&newNodeArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runNewNode, +} + +func runNewNode(ctx context.Context, args []string) error { + var nodeKey key.NodePrivate + if newNodeArgs.nodeKeyFile != "" { + var err error + nodeKey, err = readNodeKeyFile(newNodeArgs.nodeKeyFile) + if err != nil { + return fmt.Errorf("reading node key: %w", err) + } + } else { + nodeKey = key.NewNode() + } + + var machineKey key.MachinePrivate + if newNodeArgs.machineKeyFile != "" { + var err error + machineKey, err = readMachineKeyFile(newNodeArgs.machineKeyFile) + if err != nil { + return fmt.Errorf("reading machine key: %w", err) + } + } else { + machineKey = key.NewMachine() + } + + serverURL := cmp.Or(globalArgs.serverURL, tsp.DefaultServerURL) + + var serverKey key.MachinePublic + if globalArgs.controlKeyFile != "" { + var err error + serverKey, err = readControlKeyFile(globalArgs.controlKeyFile) + if err != nil { + return fmt.Errorf("reading control key: %w", err) + } + } else { + var err error + serverKey, err = tsp.DiscoverServerKey(ctx, serverURL) + if err != nil { + return fmt.Errorf("discovering server key: %w", err) + } + } + + nf := tsp.NodeFile{ + NodeKey: nodeKey, + MachineKey: machineKey, + ServerInfo: tsp.ServerInfo{URL: serverURL, Key: serverKey}, + } + + out, err := json.MarshalIndent(nf, "", " ") + if err != nil { + return fmt.Errorf("encoding node file: %w", err) + } + out = append(out, '\n') + return writeOutput(newNodeArgs.output, out) +} + +var registerArgs struct { + nodeFile string + output string + hostname string + ephemeral bool + authKey string + tags string +} + +var registerCmd = &ffcli.Command{ + Name: "register", + ShortUsage: "tsp [-s url] register -n [flags]", + ShortHelp: "Register a node key with a coordination server.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("register", flag.ExitOnError) + fs.StringVar(®isterArgs.nodeFile, "n", "", "node JSON file (required)") + fs.StringVar(®isterArgs.output, "o", "", "output file (default: stdout)") + fs.StringVar(®isterArgs.hostname, "hostname", "", "hostname to register") + fs.BoolVar(®isterArgs.ephemeral, "ephemeral", false, "register as ephemeral node") + fs.StringVar(®isterArgs.authKey, "auth-key", "", "pre-authorized auth key or file containing one") + fs.StringVar(®isterArgs.tags, "tags", "", "comma-separated ACL tags") + return fs + })(), + Exec: runRegister, +} + +func runRegister(ctx context.Context, args []string) error { + if registerArgs.nodeFile == "" { + return fmt.Errorf("flag -n (node file) is required") + } + + nf, err := tsp.ReadNodeFile(registerArgs.nodeFile) + if err != nil { + return fmt.Errorf("reading node file: %w", err) + } + + hi := hostinfo.New() + if registerArgs.hostname != "" { + hi.Hostname = registerArgs.hostname + } + + var tags []string + if registerArgs.tags != "" { + tags = strings.Split(registerArgs.tags, ",") + } + + authKey, err := resolveAuthKey(registerArgs.authKey) + if err != nil { + return err + } + + client, err := tsp.NewClient(tsp.ClientOpts{ + ServerURL: cmp.Or(globalArgs.serverURL, nf.URL), + MachineKey: nf.MachineKey, + }) + if err != nil { + return fmt.Errorf("creating client: %w", err) + } + defer client.Close() + + if globalArgs.controlKeyFile != "" { + controlKey, err := readControlKeyFile(globalArgs.controlKeyFile) + if err != nil { + return fmt.Errorf("reading control key: %w", err) + } + client.SetControlPublicKey(controlKey) + } else { + client.SetControlPublicKey(nf.ServerInfo.Key) + } + + resp, err := client.Register(ctx, tsp.RegisterOpts{ + NodeKey: nf.NodeKey, + Hostinfo: hi, + Ephemeral: registerArgs.ephemeral, + AuthKey: authKey, + Tags: tags, + }) + if err != nil { + return err + } + + out, err := json.MarshalIndent(resp, "", " ") + if err != nil { + return fmt.Errorf("encoding response: %w", err) + } + out = append(out, '\n') + + if err := writeOutput(registerArgs.output, out); err != nil { + return err + } + + if resp.AuthURL != "" { + fmt.Fprintf(os.Stderr, "AuthURL: %s\n", resp.AuthURL) + } + return nil +} + +var mapArgs struct { + nodeFile string + stream bool + peers bool + quiet bool + output string +} + +var mapCmd = &ffcli.Command{ + Name: "map", + ShortUsage: "tsp [-s url] map -n [-stream]", + ShortHelp: "Send a map request to the coordination server.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("map", flag.ExitOnError) + fs.StringVar(&mapArgs.nodeFile, "n", "", "node JSON file (required)") + fs.BoolVar(&mapArgs.stream, "stream", false, "stream map responses") + fs.BoolVar(&mapArgs.peers, "peers", true, "include peers in map response") + fs.BoolVar(&mapArgs.quiet, "quiet", true, "suppress keepalives and handled c2n ping requests from output") + fs.StringVar(&mapArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runMap, +} + +func runMap(ctx context.Context, args []string) error { + if mapArgs.nodeFile == "" { + return fmt.Errorf("flag -n (node file) is required") + } + + nf, err := tsp.ReadNodeFile(mapArgs.nodeFile) + if err != nil { + return fmt.Errorf("reading node file: %w", err) + } + + if globalArgs.serverURL != "" && globalArgs.serverURL != nf.URL { + return fmt.Errorf("server URL mismatch: -s flag is %q but node file is for %q", globalArgs.serverURL, nf.URL) + } + + hi := hostinfo.New() + + client, err := tsp.NewClient(tsp.ClientOpts{ + ServerURL: cmp.Or(globalArgs.serverURL, nf.URL), + MachineKey: nf.MachineKey, + }) + if err != nil { + return fmt.Errorf("creating client: %w", err) + } + defer client.Close() + + if globalArgs.controlKeyFile != "" { + controlKey, err := readControlKeyFile(globalArgs.controlKeyFile) + if err != nil { + return fmt.Errorf("reading control key: %w", err) + } + client.SetControlPublicKey(controlKey) + } else { + client.SetControlPublicKey(nf.ServerInfo.Key) + } + + session, err := client.Map(ctx, tsp.MapOpts{ + NodeKey: nf.NodeKey, + Hostinfo: hi, + Stream: mapArgs.stream, + OmitPeers: !mapArgs.peers, + }) + if err != nil { + return err + } + defer session.Close() + + gotResponse := false + for { + resp, err := session.Next() + if err == io.EOF { + if !gotResponse { + return fmt.Errorf("server returned no map response") + } + return nil + } + if err != nil { + return fmt.Errorf("reading map response: %w", err) + } + gotResponse = true + + if pr := resp.PingRequest; pr != nil && pr.Types == "c2n" { + if client.AnswerC2NPing(ctx, pr, session.NoiseRoundTrip) && mapArgs.quiet { + resp.PingRequest = nil + } + } + if mapArgs.quiet { + resp.KeepAlive = false + } + + if isZeroMapResponse(resp) { + continue + } + + out, err := json.MarshalIndent(resp, "", " ") + if err != nil { + return fmt.Errorf("encoding response: %w", err) + } + out = append(out, '\n') + if err := writeOutput(mapArgs.output, out); err != nil { + return err + } + } +} + +// readMachineKeyFile reads a machine private key from a file. +func readMachineKeyFile(path string) (key.MachinePrivate, error) { + data, err := os.ReadFile(path) + if err != nil { + return key.MachinePrivate{}, err + } + var k key.MachinePrivate + if err := k.UnmarshalText(bytes.TrimSpace(data)); err != nil { + return key.MachinePrivate{}, fmt.Errorf("parsing machine key from %q: %w", path, err) + } + return k, nil +} + +// readNodeKeyFile reads a node private key from a file. +func readNodeKeyFile(path string) (key.NodePrivate, error) { + data, err := os.ReadFile(path) + if err != nil { + return key.NodePrivate{}, err + } + var k key.NodePrivate + if err := k.UnmarshalText(bytes.TrimSpace(data)); err != nil { + return key.NodePrivate{}, fmt.Errorf("parsing node key from %q: %w", path, err) + } + return k, nil +} + +// readControlKeyFile reads a file containing a server's MachinePublic key +// in its MarshalText form (e.g. "mkey:..."). +func readControlKeyFile(path string) (key.MachinePublic, error) { + data, err := os.ReadFile(path) + if err != nil { + return key.MachinePublic{}, err + } + var k key.MachinePublic + if err := k.UnmarshalText(bytes.TrimSpace(data)); err != nil { + return key.MachinePublic{}, fmt.Errorf("parsing control key from %q: %w", path, err) + } + return k, nil +} + +// resolveAuthKey returns the auth key from v. If v is empty, it returns "". +// If v starts with "tskey-", it's used directly. Otherwise v is treated as a +// filename and its contents are read and trimmed. +func resolveAuthKey(v string) (string, error) { + if v == "" { + return "", nil + } + if strings.HasPrefix(strings.TrimSpace(v), "tskey-") { + return strings.TrimSpace(v), nil + } + data, err := os.ReadFile(v) + if err != nil { + return "", fmt.Errorf("reading auth key file: %w", err) + } + return strings.TrimSpace(string(data)), nil +} + +func writeOutput(path string, data []byte) error { + if path == "" { + _, err := os.Stdout.Write(data) + return err + } + return os.WriteFile(path, data, 0600) +} + +// isZeroMapResponse reports whether all fields of resp are zero values. +func isZeroMapResponse(resp *tailcfg.MapResponse) bool { + v := reflect.ValueOf(*resp) + for i := range v.NumField() { + if !v.Field(i).IsZero() { + return false + } + } + return true +} diff --git a/cmd/tta/bypass_linux.go b/cmd/tta/bypass_linux.go new file mode 100644 index 0000000000000..868cd716f8b0f --- /dev/null +++ b/cmd/tta/bypass_linux.go @@ -0,0 +1,39 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "fmt" + "syscall" + + "golang.org/x/sys/unix" + "tailscale.com/net/netmon" +) + +// bypassControlFunc is set as net.Dialer.Control so that sockets dialed by +// TTA bypass tailscaled's policy routing. Without it, sockets opened before +// tailscaled installs an exit-node route would have their packets rerouted +// via the exit node when the route is later installed, breaking the +// existing connection. +// +// We bind the socket to the default route's interface (typically the VM's +// LAN-facing NIC) rather than relying on the bypass fwmark. The fwmark +// approach is conditional on tailscaled having configured SO_MARK-based +// policy routing; binding to the underlying interface is unconditional. +func bypassControlFunc(network, address string, c syscall.RawConn) error { + ifc, err := netmon.DefaultRouteInterface() + if err != nil { + return fmt.Errorf("netmon.DefaultRouteInterface: %w", err) + } + var sockErr error + if err := c.Control(func(fd uintptr) { + sockErr = unix.SetsockoptString(int(fd), unix.SOL_SOCKET, unix.SO_BINDTODEVICE, ifc) + }); err != nil { + return err + } + if sockErr != nil { + return fmt.Errorf("setting SO_BINDTODEVICE on %q: %w", ifc, sockErr) + } + return nil +} diff --git a/cmd/tta/bypass_other.go b/cmd/tta/bypass_other.go new file mode 100644 index 0000000000000..e6b453f49633f --- /dev/null +++ b/cmd/tta/bypass_other.go @@ -0,0 +1,14 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package main + +import "syscall" + +// bypassControlFunc is a no-op on non-Linux platforms; SO_MARK is a Linux +// concept and exit-node routing only matters here for Linux VMs in vmtest. +func bypassControlFunc(network, address string, c syscall.RawConn) error { + return nil +} diff --git a/cmd/tta/fw_linux.go b/cmd/tta/fw_linux.go index 49d8d41ea4b4d..66888a45b30c9 100644 --- a/cmd/tta/fw_linux.go +++ b/cmd/tta/fw_linux.go @@ -8,7 +8,6 @@ import ( "github.com/google/nftables" "github.com/google/nftables/expr" - "tailscale.com/types/ptr" ) func init() { @@ -35,7 +34,7 @@ func addFirewallLinux() error { Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookInput, Priority: nftables.ChainPriorityFilter, - Policy: ptr.To(nftables.ChainPolicyDrop), + Policy: new(nftables.ChainPolicyDrop), } c.AddChain(inputChain) diff --git a/cmd/tta/ipassign_darwin.go b/cmd/tta/ipassign_darwin.go new file mode 100644 index 0000000000000..69a178956736d --- /dev/null +++ b/cmd/tta/ipassign_darwin.go @@ -0,0 +1,135 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package main + +import ( + "encoding/json" + "fmt" + "log" + "net" + "os/exec" + "strconv" + "time" + "unsafe" + + "golang.org/x/sys/unix" + "tailscale.com/tstest/natlab/vnet" +) + +const ( + afVSOCK = 40 // AF_VSOCK on macOS + vmaddrCIDHost = 2 // VMADDR_CID_HOST + vsockPort = 51011 // port for IP assignment protocol +) + +// sockaddrVM is the Go equivalent of struct sockaddr_vm from . +type sockaddrVM struct { + Len uint8 + Family uint8 + Reserved1 uint16 + Port uint32 + CID uint32 +} + +type netConfig struct { + IP string `json:"ip"` + Mask string `json:"mask"` + GW string `json:"gw"` +} + +// startIPAssignLoop starts a background goroutine that polls the host +// via the virtio socket for an IP assignment. When the host responds +// with a JSON config (rather than "wait"), TTA sets the IP statically +// using ifconfig and stops polling. +func startIPAssignLoop() { + go ipAssignLoop() +} + +func ipAssignLoop() { + log.Printf("ipassign: starting vsock poll loop") + var lastErr string + for attempt := 0; ; attempt++ { + resp, err := askHostForIP() + if err != nil { + if e := err.Error(); e != lastErr { + log.Printf("ipassign: attempt %d: %v", attempt, err) + lastErr = e + } + time.Sleep(500 * time.Millisecond) + continue + } + if resp == "wait" { + time.Sleep(500 * time.Millisecond) + continue + } + var nc netConfig + if err := json.Unmarshal([]byte(resp), &nc); err != nil { + log.Printf("ipassign: bad config: %v", err) + time.Sleep(500 * time.Millisecond) + continue + } + if err := setStaticIP(nc); err != nil { + log.Printf("ipassign: %v", err) + time.Sleep(500 * time.Millisecond) + continue + } + log.Printf("ipassign: configured en0 with %s/%s gw %s", nc.IP, nc.Mask, nc.GW) + + // Switch the driver address from the DNS name to the IP directly + // (avoids DNS resolution delay) and kick the dial-out loop so it + // retries immediately with the new address. + ipAddr := net.JoinHostPort(vnet.TestDriverIPv4().String(), strconv.Itoa(vnet.TestDriverPort)) + *driverAddr = ipAddr + log.Printf("ipassign: switched driver addr to %s", ipAddr) + resetDialCancels() + return + } +} + +// askHostForIP connects to the host via AF_VSOCK and reads the response. +func askHostForIP() (string, error) { + fd, err := unix.Socket(afVSOCK, unix.SOCK_STREAM, 0) + if err != nil { + return "", fmt.Errorf("socket: %w", err) + } + defer unix.Close(fd) + + // Set a short connect+read timeout via SO_RCVTIMEO. + tv := unix.Timeval{Sec: 1} + unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv) + + addr := sockaddrVM{ + Len: uint8(unsafe.Sizeof(sockaddrVM{})), + Family: afVSOCK, + Port: vsockPort, + CID: vmaddrCIDHost, + } + _, _, errno := unix.RawSyscall(unix.SYS_CONNECT, uintptr(fd), + uintptr(unsafe.Pointer(&addr)), unsafe.Sizeof(addr)) + if errno != 0 { + return "", fmt.Errorf("connect: %w", errno) + } + + var buf [1024]byte + n, err := unix.Read(fd, buf[:]) + if err != nil { + return "", fmt.Errorf("read: %w", err) + } + return string(buf[:n]), nil +} + +// setStaticIP configures en0 with a static IP address and default route. +func setStaticIP(nc netConfig) error { + out, err := exec.Command("ifconfig", "en0", nc.IP, "netmask", nc.Mask, "up").CombinedOutput() + if err != nil { + return fmt.Errorf("ifconfig: %v: %s", err, out) + } + out, err = exec.Command("route", "add", "default", nc.GW).CombinedOutput() + if err != nil { + return fmt.Errorf("route add: %v: %s", err, out) + } + return nil +} diff --git a/cmd/tta/ipassign_other.go b/cmd/tta/ipassign_other.go new file mode 100644 index 0000000000000..dc331b5e0c3f4 --- /dev/null +++ b/cmd/tta/ipassign_other.go @@ -0,0 +1,14 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !darwin + +package main + +// startIPAssignLoop is a no-op on non-macOS platforms. +// macOS VMs use vsock-based IP assignment to bypass slow DHCP. +func startIPAssignLoop() {} + +// Reference resetDialCancels to prevent unused-function lint errors. +// It's called from ipassign_darwin.go on macOS builds. +var _ = resetDialCancels diff --git a/cmd/tta/restart_tailscaled_linux.go b/cmd/tta/restart_tailscaled_linux.go new file mode 100644 index 0000000000000..accf2b404076a --- /dev/null +++ b/cmd/tta/restart_tailscaled_linux.go @@ -0,0 +1,47 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "fmt" + "os" + "strconv" + "strings" +) + +func init() { + restartTailscaled = restartTailscaledLinux +} + +// restartTailscaledLinux finds the tailscaled process by walking /proc and +// sends it SIGKILL. On gokrazy, the supervisor will restart tailscaled within +// a few seconds. The PID of the process that was killed is returned. +func restartTailscaledLinux() (int, error) { + ents, err := os.ReadDir("/proc") + if err != nil { + return 0, err + } + for _, e := range ents { + pid, err := strconv.Atoi(e.Name()) + if err != nil { + continue + } + comm, err := os.ReadFile("/proc/" + e.Name() + "/comm") + if err != nil { + continue + } + if strings.TrimSpace(string(comm)) != "tailscaled" { + continue + } + proc, err := os.FindProcess(pid) + if err != nil { + return 0, err + } + if err := proc.Kill(); err != nil { + return 0, fmt.Errorf("killing tailscaled pid %d: %w", pid, err) + } + return pid, nil + } + return 0, fmt.Errorf("tailscaled process not found in /proc") +} diff --git a/cmd/tta/tta.go b/cmd/tta/tta.go index 377d01c9487f7..5dd1eddb9fedd 100644 --- a/cmd/tta/tta.go +++ b/cmd/tta/tta.go @@ -15,6 +15,7 @@ import ( "context" "errors" "flag" + "fmt" "io" "log" "net" @@ -23,7 +24,9 @@ import ( "net/url" "os" "os/exec" + "path/filepath" "regexp" + "runtime" "strconv" "strings" "sync" @@ -38,6 +41,17 @@ import ( "tailscale.com/version/distro" ) +// connContextKeyType is the type of connContextKey, which isn't of type +// `string` to avoid collisions while being used as a context key. +type connContextKeyType string + +const ( + // connContextKey is the key for looking up the TCP connection + // corresponding to an HTTP request coming in from testing + // infrastructure. + connContextKey connContextKeyType = "conn-context-key" +) + var ( driverAddr = flag.String("driver", "test-driver.tailscale:8008", "address of the test driver; by default we use the DNS name test-driver.tailscale which is special cased in the emulated network's DNS server") ) @@ -55,9 +69,13 @@ func serveCmd(w http.ResponseWriter, cmd string, args ...string) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") if err != nil { w.Header().Set("Exec-Err", err.Error()) + if exiterr, ok := err.(*exec.ExitError); ok { + w.Header().Set("Exec-Exit-Code", strconv.Itoa(exiterr.ExitCode())) + } w.WriteHeader(500) log.Printf("Err on serveCmd for %q %v, %d bytes of output: %v", cmd, args, len(out), err) } else { + w.Header().Set("Exec-Exit-Code", "0") log.Printf("Did serveCmd for %q %v, %d bytes of output", cmd, args, len(out)) } w.Write(out) @@ -87,11 +105,15 @@ func main() { } flag.Parse() + // On macOS VMs, start polling the host via vsock for an IP assignment. + // This bypasses DHCP for near-instant network configuration. + startIPAssignLoop() + debug := false if distro.Get() == distro.Gokrazy { cmdLine, _ := os.ReadFile("/proc/cmdline") explicitNS := false - for _, s := range strings.Fields(string(cmdLine)) { + for s := range strings.FieldsSeq(string(cmdLine)) { if ns, ok := strings.CutPrefix(s, "tta.nameserver="); ok { err := atomicfile.WriteFile("/tmp/resolv.conf", []byte("nameserver "+ns+"\n"), 0644) log.Printf("Wrote /tmp/resolv.conf: %v", err) @@ -139,8 +161,12 @@ func main() { } ttaMux.ServeHTTP(w, r) }) + var hs http.Server hs.Handler = &serveMux + hs.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, connContextKey, c) + } revSt := revDialState{ needConnCh: make(chan bool, 1), debug: debug, @@ -162,9 +188,221 @@ func main() { return }) ttaMux.HandleFunc("/up", func(w http.ResponseWriter, r *http.Request) { - serveCmd(w, "tailscale", "up", "--login-server=http://control.tailscale") + args := []string{"up", "--login-server=http://control.tailscale"} + if routes := r.URL.Query().Get("advertise-routes"); routes != "" { + args = append(args, "--advertise-routes="+routes) + } + if snat := r.URL.Query().Get("snat-subnet-routes"); snat != "" { + args = append(args, "--snat-subnet-routes="+snat) + } + if r.URL.Query().Get("accept-routes") == "true" { + args = append(args, "--accept-routes") + } + serveCmd(w, "tailscale", args...) + }) + ttaMux.HandleFunc("/set", func(w http.ResponseWriter, r *http.Request) { + args := []string{"set"} + if r.URL.Query().Get("accept-routes") == "true" { + args = append(args, "--accept-routes") + } + if routes := r.URL.Query().Get("advertise-routes"); routes != "" { + args = append(args, "--advertise-routes="+routes) + } + if snat := r.URL.Query().Get("snat-subnet-routes"); snat != "" { + args = append(args, "--snat-subnet-routes="+snat) + } + serveCmd(w, "tailscale", args...) + }) + ttaMux.HandleFunc("/ip", func(w http.ResponseWriter, r *http.Request) { + conn, ok := r.Context().Value(connContextKey).(net.Conn) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Write([]byte(conn.LocalAddr().String())) + }) + ttaMux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { + host := r.URL.Query().Get("host") + if distro.Get() == distro.Gokrazy { + // The busybox in question here is the breakglass busybox inside the + // natlab QEMU image. + serveCmd(w, "/usr/local/bin/busybox", "ping", "-c", "4", "-W", "1", host) + } else { + serveCmd(w, "ping", "-c", "4", "-W", "1", host) + } + }) + ttaMux.HandleFunc("/add-route", func(w http.ResponseWriter, r *http.Request) { + prefix := r.URL.Query().Get("prefix") + via := r.URL.Query().Get("via") + if prefix == "" || via == "" { + http.Error(w, "missing prefix or via", http.StatusBadRequest) + return + } + switch runtime.GOOS { + case "linux": + serveCmd(w, "ip", "route", "add", prefix, "via", via) + default: + http.Error(w, "add-route not supported on "+runtime.GOOS, http.StatusNotImplemented) + } + }) + ttaMux.HandleFunc("/start-webserver", func(w http.ResponseWriter, r *http.Request) { + port := r.URL.Query().Get("port") + name := r.URL.Query().Get("name") + if port == "" { + http.Error(w, "missing port", http.StatusBadRequest) + return + } + if name == "" { + name = "unnamed" + } + log.Printf("Starting webserver on port %s as %q", port, name) + go func() { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + host, _, _ := net.SplitHostPort(r.RemoteAddr) + fmt.Fprintf(w, "Hello world I am %s from %s", name, host) + }) + if err := http.ListenAndServe(":"+port, mux); err != nil { + log.Printf("webserver on :%s failed: %v", port, err) + } + }() + io.WriteString(w, "OK\n") + }) + ttaMux.HandleFunc("/taildrop-send", func(w http.ResponseWriter, r *http.Request) { + to := r.URL.Query().Get("to") // peer's Tailscale IP + name := r.URL.Query().Get("name") + if to == "" || name == "" { + http.Error(w, "missing to or name", http.StatusBadRequest) + return + } + if strings.ContainsAny(name, "/\\") { + http.Error(w, "bad name", http.StatusBadRequest) + return + } + dir, err := os.MkdirTemp("", "taildrop-send-") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer os.RemoveAll(dir) + path := filepath.Join(dir, name) + f, err := os.Create(path) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if _, err := io.Copy(f, r.Body); err != nil { + f.Close() + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := f.Close(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + serveCmd(w, "tailscale", "file", "cp", path, to+":") + }) + ttaMux.HandleFunc("/taildrop-recv", func(w http.ResponseWriter, r *http.Request) { + dir, err := os.MkdirTemp("", "taildrop-recv-") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer os.RemoveAll(dir) + ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second) + defer cancel() + cmd := exec.CommandContext(ctx, absify("tailscale"), "file", "get", "--wait", dir) + if out, err := cmd.CombinedOutput(); err != nil { + http.Error(w, fmt.Sprintf("tailscale file get: %v\n%s", err, out), http.StatusInternalServerError) + return + } + ents, err := os.ReadDir(dir) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if len(ents) != 1 { + http.Error(w, fmt.Sprintf("got %d files, want 1", len(ents)), http.StatusInternalServerError) + return + } + data, err := os.ReadFile(filepath.Join(dir, ents[0].Name())) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Taildrop-Filename", ents[0].Name()) + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(data) + }) + ttaMux.HandleFunc("/http-get", func(w http.ResponseWriter, r *http.Request) { + targetURL := r.URL.Query().Get("url") + if targetURL == "" { + http.Error(w, "missing url", http.StatusBadRequest) + return + } + log.Printf("HTTP GET %s", targetURL) + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + // Use Tailscale's SOCKS5 proxy if available, so traffic to Tailscale + // subnet routes goes through the WireGuard tunnel instead of the + // host network stack (which may not have the routes, especially + // in userspace networking mode). + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // Try the Tailscale localapi proxy dialer first. + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + var d net.Dialer + return d.DialContext(ctx, network, addr) + } + port, _ := strconv.ParseUint(portStr, 10, 16) + var lc local.Client + conn, err := lc.UserDial(ctx, network, host, uint16(port)) + if err == nil { + return conn, nil + } + log.Printf("http-get: UserDial failed, falling back to direct: %v", err) + var d net.Dialer + return d.DialContext(ctx, network, addr) + }, + }, + } + resp, err := client.Do(req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer resp.Body.Close() + w.Header().Set("X-Upstream-Status", strconv.Itoa(resp.StatusCode)) + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) }) ttaMux.HandleFunc("/fw", addFirewallHandler) + ttaMux.HandleFunc("/wg-server-up", func(w http.ResponseWriter, r *http.Request) { + if wgServerUp == nil { + http.Error(w, "wg-server-up not supported on this platform", http.StatusNotImplemented) + return + } + wgServerUp(w, r) + }) + ttaMux.HandleFunc("/restart-tailscaled", func(w http.ResponseWriter, r *http.Request) { + if restartTailscaled == nil { + http.Error(w, "restart-tailscaled not supported on this platform", http.StatusNotImplemented) + return + } + pid, err := restartTailscaled() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + fmt.Fprintf(w, "killed tailscaled pid %d (supervisor will respawn)\n", pid) + }) ttaMux.HandleFunc("/logs", func(w http.ResponseWriter, r *http.Request) { logBuf.mu.Lock() defer logBuf.mu.Unlock() @@ -186,10 +424,48 @@ func main() { revSt.runDialOutLoop(conns) } +// dialCancels tracks cancel funcs for in-flight connect() and sleep contexts. +// resetDialCancels cancels them all so the dial loop retries immediately. +var ( + dialCancelMu sync.Mutex + dialCancels set.HandleSet[context.CancelFunc] +) + +// registerDialCancel adds a cancel func and returns a handle for removal. +func registerDialCancel(cancel context.CancelFunc) set.Handle { + dialCancelMu.Lock() + defer dialCancelMu.Unlock() + return dialCancels.Add(cancel) +} + +// unregisterDialCancel removes a previously registered cancel func. +func unregisterDialCancel(h set.Handle) { + dialCancelMu.Lock() + defer dialCancelMu.Unlock() + delete(dialCancels, h) +} + +// resetDialCancels cancels all in-flight connect and sleep contexts, +// causing the dial loop to retry immediately with the updated driver address. +func resetDialCancels() { + dialCancelMu.Lock() + defer dialCancelMu.Unlock() + for h, cancel := range dialCancels { + cancel() + delete(dialCancels, h) + } +} + func connect() (net.Conn, error) { - var d net.Dialer + d := net.Dialer{ + Control: bypassControlFunc, + } ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() + h := registerDialCancel(cancel) + defer func() { + cancel() + unregisterDialCancel(h) + }() c, err := d.DialContext(ctx, "tcp", *driverAddr) if err != nil { return nil, err @@ -286,7 +562,11 @@ func (s *revDialState) runDialOutLoop(conns chan<- net.Conn) { log.Printf("[dial-driver] connect failure: %v", s) } lastErr = s - time.Sleep(time.Second) + sleepCtx, sleepCancel := context.WithTimeout(context.Background(), time.Second) + h := registerDialCancel(sleepCancel) + <-sleepCtx.Done() + sleepCancel() + unregisterDialCancel(h) continue } if !connected { @@ -327,6 +607,16 @@ func addFirewallHandler(w http.ResponseWriter, r *http.Request) { var addFirewall func() error // set by fw_linux.go +// wgServerUp brings up a userspace WireGuard "Mullvad-style" exit-node +// server on this VM. It is set by wgserver_linux.go and is nil on +// non-Linux. +var wgServerUp func(w http.ResponseWriter, r *http.Request) + +// restartTailscaled sends SIGKILL to the local tailscaled process so the +// gokrazy supervisor restarts it. It is set by restart_tailscaled_linux.go +// and is nil on non-Linux. +var restartTailscaled func() (pid int, err error) + // logBuffer is a bytes.Buffer that is safe for concurrent use // intended to capture early logs from the process, even if // gokrazy's syslog streaming isn't working or yet working. diff --git a/cmd/tta/wgserver_linux.go b/cmd/tta/wgserver_linux.go new file mode 100644 index 0000000000000..10d6bbe282ade --- /dev/null +++ b/cmd/tta/wgserver_linux.go @@ -0,0 +1,155 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "cmp" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "fmt" + "log" + "net/http" + "os" + "os/exec" + "sync" + + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "golang.org/x/crypto/curve25519" + "tailscale.com/wgengine/wgcfg" +) + +func init() { + wgServerUp = wgServerUpLinux +} + +var ( + wgServerMu sync.Mutex + wgServerDev *device.Device // retained so the goroutines stay alive +) + +// wgServerUpLinux brings up a userspace WireGuard interface on the local VM +// configured as a single-peer "Mullvad-style" exit node, then sets up the +// kernel-side IP/forwarding/MASQUERADE so that decrypted traffic from the +// peer egresses to the test internet. +// +// Required URL query parameters: +// - addr: CIDR for the WG interface (e.g. "10.64.0.1/24") +// - listen-port: WG listen port +// - peer-pub-b64: base64-encoded 32-byte WG public key of the only peer +// - peer-allowed-ip: prefix the peer is allowed to source from +// (e.g. "10.64.0.2/32") +// - masq-src: prefix to MASQUERADE on egress (e.g. "10.64.0.0/24") +// +// Optional: +// - name: TUN device name (default "wg0") +// +// On success, it writes "PUBKEY=\n" — the freshly generated public +// key the caller must pin as the peer's WG public key. +func wgServerUpLinux(w http.ResponseWriter, r *http.Request) { + wgServerMu.Lock() + defer wgServerMu.Unlock() + if wgServerDev != nil { + http.Error(w, "wg server already up", http.StatusConflict) + return + } + + q := r.URL.Query() + name := cmp.Or(q.Get("name"), "wg0") + addr := q.Get("addr") + listenPort := q.Get("listen-port") + peerPubB64 := q.Get("peer-pub-b64") + peerAllowedIP := q.Get("peer-allowed-ip") + masqSrc := q.Get("masq-src") + for _, kv := range []struct{ k, v string }{ + {"addr", addr}, + {"listen-port", listenPort}, + {"peer-pub-b64", peerPubB64}, + {"peer-allowed-ip", peerAllowedIP}, + {"masq-src", masqSrc}, + } { + if kv.v == "" { + http.Error(w, "missing "+kv.k, http.StatusBadRequest) + return + } + } + + peerPub, err := base64.StdEncoding.DecodeString(peerPubB64) + if err != nil || len(peerPub) != 32 { + http.Error(w, fmt.Sprintf("bad peer-pub-b64: %v (len=%d)", err, len(peerPub)), http.StatusBadRequest) + return + } + + var priv [32]byte + if _, err := rand.Read(priv[:]); err != nil { + http.Error(w, "rand: "+err.Error(), http.StatusInternalServerError) + return + } + // X25519 key clamping. + priv[0] &= 248 + priv[31] = (priv[31] & 127) | 64 + + pub, err := curve25519.X25519(priv[:], curve25519.Basepoint) + if err != nil { + http.Error(w, "deriving pubkey: "+err.Error(), http.StatusInternalServerError) + return + } + + tdev, err := tun.CreateTUN(name, device.DefaultMTU) + if err != nil { + http.Error(w, "tun.CreateTUN: "+err.Error(), http.StatusInternalServerError) + return + } + wglog := &device.Logger{ + Verbosef: func(string, ...any) {}, + Errorf: func(f string, a ...any) { log.Printf("wg-server: "+f, a...) }, + } + dev := wgcfg.NewDevice(tdev, conn.NewDefaultBind(), wglog) + + uapi := fmt.Sprintf("private_key=%s\nlisten_port=%s\npublic_key=%s\nallowed_ip=%s\n", + hex.EncodeToString(priv[:]), listenPort, + hex.EncodeToString(peerPub), peerAllowedIP) + if err := dev.IpcSet(uapi); err != nil { + dev.Close() + http.Error(w, "IpcSet: "+err.Error(), http.StatusInternalServerError) + return + } + if err := dev.Up(); err != nil { + dev.Close() + http.Error(w, "dev.Up: "+err.Error(), http.StatusInternalServerError) + return + } + + steps := []struct { + why string + exec []string + file struct{ path, data string } + }{ + {why: "ip addr add", exec: []string{"ip", "addr", "add", addr, "dev", name}}, + {why: "ip link up", exec: []string{"ip", "link", "set", name, "up"}}, + {why: "enable forwarding", file: struct{ path, data string }{"/proc/sys/net/ipv4/ip_forward", "1\n"}}, + {why: "FORWARD policy", exec: []string{"iptables", "-P", "FORWARD", "ACCEPT"}}, + {why: "MASQUERADE", exec: []string{"iptables", "-t", "nat", "-A", "POSTROUTING", "-s", masqSrc, "-j", "MASQUERADE"}}, + } + for _, s := range steps { + if s.file.path != "" { + if err := os.WriteFile(s.file.path, []byte(s.file.data), 0644); err != nil { + dev.Close() + http.Error(w, fmt.Sprintf("%s: %v", s.why, err), http.StatusInternalServerError) + return + } + continue + } + if out, err := exec.Command(s.exec[0], s.exec[1:]...).CombinedOutput(); err != nil { + dev.Close() + http.Error(w, fmt.Sprintf("%s: %v: %s", s.why, err, out), http.StatusInternalServerError) + return + } + } + + wgServerDev = dev + fmt.Fprintf(w, "PUBKEY=%s\n", base64.StdEncoding.EncodeToString(pub)) +} diff --git a/cmd/vet/lowerell/analyzer.go b/cmd/vet/lowerell/analyzer.go new file mode 100644 index 0000000000000..a62f79bdcc6bb --- /dev/null +++ b/cmd/vet/lowerell/analyzer.go @@ -0,0 +1,132 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package lowerell forbids variables named "l" (lowercase ell) or "I" +// (uppercase i), because they are hard to distinguish from the digit +// "1" and from each other in too many fonts. +package lowerell + +import ( + "go/ast" + "go/token" + + "golang.org/x/tools/go/analysis" +) + +// Analyzer reports variables named "l" (lowercase ell) or "I" (uppercase i). +var Analyzer = &analysis.Analyzer{ + Name: "lowerell", + Doc: `forbid variables named "l" (lowercase ell) or "I" (uppercase i), which are hard to distinguish from "1"`, + Run: run, +} + +// messages maps a banned identifier name to the diagnostic shown to users. +// Each message names the specific symbol that triggered it, so the +// reader does not have to guess which of "l" or "I" they typed. +var messages = map[string]string{ + "l": `do not use "l" (lowercase ell) as a variable name; it is hard to distinguish from "1" and "I" in too many fonts; see https://github.com/tailscale/tailscale/issues/19631`, + "I": `do not use "I" (uppercase i) as a variable name; it is hard to distinguish from "1" and "l" in too many fonts; see https://github.com/tailscale/tailscale/issues/19631`, +} + +// reported tracks identifier positions already reported, to avoid duplicate +// diagnostics when the same declaration is reachable from multiple AST nodes. +type reportedSet map[token.Pos]bool + +func (rs reportedSet) check(pass *analysis.Pass, ident *ast.Ident) { + if ident == nil { + return + } + msg, ok := messages[ident.Name] + if !ok { + return + } + if rs[ident.Pos()] { + return + } + rs[ident.Pos()] = true + pass.Reportf(ident.Pos(), "%s", msg) +} + +func (rs reportedSet) checkFieldList(pass *analysis.Pass, fl *ast.FieldList) { + if fl == nil { + return + } + for _, f := range fl.List { + for _, n := range f.Names { + rs.check(pass, n) + } + } +} + +func run(pass *analysis.Pass) (any, error) { + rs := reportedSet{} + + for _, file := range pass.Files { + ast.Inspect(file, func(n ast.Node) bool { + switch n := n.(type) { + case *ast.FuncDecl: + // Receiver name. + rs.checkFieldList(pass, n.Recv) + // Parameters, results, and type parameters + // are checked via the FuncType case below. + + case *ast.FuncType: + rs.checkFieldList(pass, n.TypeParams) + rs.checkFieldList(pass, n.Params) + rs.checkFieldList(pass, n.Results) + + case *ast.StructType: + rs.checkFieldList(pass, n.Fields) + + case *ast.GenDecl: + if n.Tok != token.VAR && n.Tok != token.CONST { + return true + } + for _, spec := range n.Specs { + vs, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + for _, name := range vs.Names { + rs.check(pass, name) + } + } + + case *ast.AssignStmt: + if n.Tok != token.DEFINE { + return true + } + for _, lhs := range n.Lhs { + if id, ok := lhs.(*ast.Ident); ok { + rs.check(pass, id) + } + } + + case *ast.RangeStmt: + if n.Tok != token.DEFINE { + return true + } + if id, ok := n.Key.(*ast.Ident); ok { + rs.check(pass, id) + } + if id, ok := n.Value.(*ast.Ident); ok { + rs.check(pass, id) + } + + case *ast.TypeSwitchStmt: + // switch l := x.(type) { ... } + as, ok := n.Assign.(*ast.AssignStmt) + if !ok || as.Tok != token.DEFINE { + return true + } + for _, lhs := range as.Lhs { + if id, ok := lhs.(*ast.Ident); ok { + rs.check(pass, id) + } + } + } + return true + }) + } + return nil, nil +} diff --git a/cmd/vet/lowerell/analyzer_test.go b/cmd/vet/lowerell/analyzer_test.go new file mode 100644 index 0000000000000..c566c2ec43ede --- /dev/null +++ b/cmd/vet/lowerell/analyzer_test.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package lowerell + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" +) + +func TestAnalyzer(t *testing.T) { + testdata := analysistest.TestData() + analysistest.Run(t, testdata, Analyzer, "example") +} diff --git a/cmd/vet/lowerell/testdata/src/example/example.go b/cmd/vet/lowerell/testdata/src/example/example.go new file mode 100644 index 0000000000000..c67c1978135ef --- /dev/null +++ b/cmd/vet/lowerell/testdata/src/example/example.go @@ -0,0 +1,100 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package example + +import "sync" + +// Bad: var declarations. +var l int // want `do not use "l"` +var I int // want `do not use "I"` + +// OK: variables named "ll", "II", "i" are fine. +var ( + ll int + II int + i int +) + +// Bad: const declaration in a function scope. +func F0() { + const l = 3 // want `do not use "l"` + const I = 4 // want `do not use "I"` + _ = l + _ = I +} + +// Bad: function parameters. +func F1a(l int) {} // want `do not use "l"` +func F1b(I int) {} // want `do not use "I"` + +// Bad: named return values. +func F2a() (l int) { return } // want `do not use "l"` +func F2b() (I int) { return } // want `do not use "I"` + +// Bad: receiver names. +type T struct{} + +func (l *T) Ml() {} // want `do not use "l"` +func (I *T) MI() {} // want `do not use "I"` + +// Bad: struct fields. +type S struct { + l int // want `do not use "l"` + I int // want `do not use "I"` +} + +// Bad: short variable declarations. +func F3() { + l := 1 // want `do not use "l"` + I := 2 // want `do not use "I"` + _ = l + _ = I +} + +// Bad: var statement inside a function. +func F4() { + var l int // want `do not use "l"` + var I int // want `do not use "I"` + _ = l + _ = I +} + +// Bad: range key/value. +func F5(xs []int) { + for l, v := range xs { // want `do not use "l"` + _ = l + _ = v + } + for _, I := range xs { // want `do not use "I"` + _ = I + } +} + +// Bad: type parameters. +func F6a[l any](x l) l { return x } // want `do not use "l"` +func F6b[I any](x I) I { return x } // want `do not use "I"` + +// Bad: type switch guards. +func F7(x any) { + switch l := x.(type) { // want `do not use "l"` + case int: + _ = l + } + switch I := x.(type) { // want `do not use "I"` + case int: + _ = I + } +} + +// OK: clean code with no banned variables. +func F8() { + count := 0 + for i := 0; i < 10; i++ { + count++ + } + _ = count +} + +// OK: sync.Mutex named "mu". +var mu sync.Mutex diff --git a/cmd/vet/subtestnames/analyzer.go b/cmd/vet/subtestnames/analyzer.go new file mode 100644 index 0000000000000..d37a992e59656 --- /dev/null +++ b/cmd/vet/subtestnames/analyzer.go @@ -0,0 +1,227 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package subtestnames checks that t.Run subtest names don't contain characters +// that require quoting or escaping when re-running via "go test -run". +// +// Go's testing package rewrites subtest names: spaces become underscores, +// non-printable characters get escaped, and regex metacharacters require +// escaping in -run patterns. This makes it hard to re-run specific subtests +// or search for them in code. +// +// This analyzer flags: +// - Direct t.Run calls with string literal names containing bad characters +// - t.Run calls using tt.name (or similar) where tt ranges over a slice/map +// of test cases with string literal names containing bad characters +package subtestnames + +import ( + "go/ast" + "go/token" + "strconv" + "strings" + "unicode" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" +) + +// Analyzer checks that t.Run subtest names are clean for use with "go test -run". +var Analyzer = &analysis.Analyzer{ + Name: "subtestnames", + Doc: "check that t.Run subtest names don't require quoting when re-running via go test -run", + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Run: run, +} + +// badChars are characters that are problematic in subtest names. +// Spaces are rewritten to underscores by testing.rewrite, and regex +// metacharacters require escaping in -run patterns. +const badChars = " \t\n\r^$.*+?()[]{}|\\'\"#" + +// hasBadChar reports whether s contains any character that would be +// problematic in a subtest name. +func hasBadChar(s string) bool { + return strings.ContainsAny(s, badChars) || strings.ContainsFunc(s, func(r rune) bool { + return !unicode.IsPrint(r) + }) +} + +// hasBadDash reports whether s starts or ends with a dash, which is +// problematic in subtest names because "go test -run" may interpret a +// leading dash as a flag, and trailing dashes are confusing. +func hasBadDash(s string) bool { + return strings.HasPrefix(s, "-") || strings.HasSuffix(s, "-") +} + +func run(pass *analysis.Pass) (any, error) { + insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + + // Build a stack of enclosing nodes so we can find the RangeStmt + // enclosing a given t.Run call. + nodeFilter := []ast.Node{ + (*ast.RangeStmt)(nil), + (*ast.CallExpr)(nil), + } + + var rangeStack []*ast.RangeStmt + + insp.Nodes(nodeFilter, func(n ast.Node, push bool) bool { + switch n := n.(type) { + case *ast.RangeStmt: + if push { + rangeStack = append(rangeStack, n) + } else { + rangeStack = rangeStack[:len(rangeStack)-1] + } + return true + case *ast.CallExpr: + if !push { + return true + } + checkCallExpr(pass, n, rangeStack) + return true + } + return true + }) + + return nil, nil +} + +func checkCallExpr(pass *analysis.Pass, call *ast.CallExpr, rangeStack []*ast.RangeStmt) { + // Check if this is a t.Run(...) or b.Run(...) call. + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok || sel.Sel.Name != "Run" || len(call.Args) < 2 { + return + } + + // Verify the receiver is *testing.T, *testing.B, or *testing.F. + if !isTestingTBF(pass, sel) { + return + } + + nameArg := call.Args[0] + + // Case 1: Direct string literal, e.g. t.Run("foo bar", ...) + if lit, ok := nameArg.(*ast.BasicLit); ok && lit.Kind == token.STRING { + val, err := strconv.Unquote(lit.Value) + if err != nil { + return + } + if hasBadChar(val) { + pass.Reportf(lit.Pos(), "subtest name %s contains characters that require quoting in go test -run patterns", lit.Value) + } else if hasBadDash(val) { + pass.Reportf(lit.Pos(), "subtest name %s starts or ends with '-' which is problematic in go test -run patterns", lit.Value) + } + return + } + + // Case 2: Selector expression like tt.name, tc.name, etc. + // where tt is a range variable over a slice/map of test cases. + selExpr, ok := nameArg.(*ast.SelectorExpr) + if !ok { + return + } + ident, ok := selExpr.X.(*ast.Ident) + if !ok { + return + } + + // Find the enclosing range statement where ident is the value variable. + for i := len(rangeStack) - 1; i >= 0; i-- { + rs := rangeStack[i] + valIdent, ok := rs.Value.(*ast.Ident) + if !ok || valIdent.Obj != ident.Obj { + continue + } + // Found the range statement. Check the source being iterated. + checkRangeSource(pass, rs.X, selExpr.Sel) + return + } +} + +// isTestingTBF checks whether sel looks like a method call on *testing.T, *testing.B, or *testing.F. +func isTestingTBF(pass *analysis.Pass, sel *ast.SelectorExpr) bool { + typ := pass.TypesInfo.TypeOf(sel.X) + if typ != nil { + s := typ.String() + return s == "*testing.T" || s == "*testing.B" || s == "*testing.F" + } + return false +} + +// checkRangeSource examines the expression being ranged over and checks +// composite literal elements for bad subtest name fields. +func checkRangeSource(pass *analysis.Pass, rangeExpr ast.Expr, fieldName *ast.Ident) { + switch x := rangeExpr.(type) { + case *ast.Ident: + if x.Obj == nil { + return + } + switch decl := x.Obj.Decl.(type) { + case *ast.AssignStmt: + // e.g. tests := []struct{...}{...} + for _, rhs := range decl.Rhs { + checkCompositeLit(pass, rhs, fieldName) + } + case *ast.ValueSpec: + // e.g. var tests = []struct{...}{...} + for _, val := range decl.Values { + checkCompositeLit(pass, val, fieldName) + } + } + case *ast.CompositeLit: + checkCompositeLit(pass, x, fieldName) + } +} + +// checkCompositeLit checks a composite literal (slice/map) for elements +// that have a field with a bad subtest name. +func checkCompositeLit(pass *analysis.Pass, expr ast.Expr, fieldName *ast.Ident) { + comp, ok := expr.(*ast.CompositeLit) + if !ok { + return + } + + for _, elt := range comp.Elts { + // For map literals, check the value. + if kv, ok := elt.(*ast.KeyValueExpr); ok { + elt = kv.Value + } + checkStructLitField(pass, elt, fieldName) + } +} + +// checkStructLitField checks a struct literal for a field with the given name +// that contains a bad subtest name string. +func checkStructLitField(pass *analysis.Pass, expr ast.Expr, fieldName *ast.Ident) { + comp, ok := expr.(*ast.CompositeLit) + if !ok { + return + } + + for _, elt := range comp.Elts { + kv, ok := elt.(*ast.KeyValueExpr) + if !ok { + continue + } + key, ok := kv.Key.(*ast.Ident) + if !ok || key.Name != fieldName.Name { + continue + } + lit, ok := kv.Value.(*ast.BasicLit) + if !ok || lit.Kind != token.STRING { + continue + } + val, err := strconv.Unquote(lit.Value) + if err != nil { + continue + } + if hasBadChar(val) { + pass.Reportf(lit.Pos(), "subtest name %s contains characters that require quoting in go test -run patterns", lit.Value) + } else if hasBadDash(val) { + pass.Reportf(lit.Pos(), "subtest name %s starts or ends with '-' which is problematic in go test -run patterns", lit.Value) + } + } +} diff --git a/cmd/vet/subtestnames/analyzer_test.go b/cmd/vet/subtestnames/analyzer_test.go new file mode 100644 index 0000000000000..b051a0369c4f6 --- /dev/null +++ b/cmd/vet/subtestnames/analyzer_test.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package subtestnames + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" +) + +func TestAnalyzer(t *testing.T) { + testdata := analysistest.TestData() + analysistest.Run(t, testdata, Analyzer, "example") +} diff --git a/cmd/vet/subtestnames/testdata/src/example/example_test.go b/cmd/vet/subtestnames/testdata/src/example/example_test.go new file mode 100644 index 0000000000000..f76599c324c5e --- /dev/null +++ b/cmd/vet/subtestnames/testdata/src/example/example_test.go @@ -0,0 +1,112 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package example + +import "testing" + +func TestDirect(t *testing.T) { + // Bad: spaces + t.Run("that everything's cool", func(t *testing.T) {}) // want `subtest name "that everything's cool" contains characters that require quoting` + + // Bad: apostrophe + t.Run("it's working", func(t *testing.T) {}) // want `subtest name "it's working" contains characters that require quoting` + + // Bad: regex metacharacters + t.Run("test(foo)", func(t *testing.T) {}) // want `subtest name "test\(foo\)" contains characters that require quoting` + t.Run("test[0]", func(t *testing.T) {}) // want `subtest name "test\[0\]" contains characters that require quoting` + t.Run("a|b", func(t *testing.T) {}) // want `subtest name "a\|b" contains characters that require quoting` + t.Run("a*b", func(t *testing.T) {}) // want `subtest name "a\*b" contains characters that require quoting` + t.Run("a+b", func(t *testing.T) {}) // want `subtest name "a\+b" contains characters that require quoting` + t.Run("a.b", func(t *testing.T) {}) // want `subtest name "a\.b" contains characters that require quoting` + t.Run("^start", func(t *testing.T) {}) // want `subtest name "\^start" contains characters that require quoting` + t.Run("end$", func(t *testing.T) {}) // want `subtest name "end\$" contains characters that require quoting` + t.Run("a{2}", func(t *testing.T) {}) // want `subtest name "a\{2\}" contains characters that require quoting` + t.Run("a?b", func(t *testing.T) {}) // want `subtest name "a\?b" contains characters that require quoting` + t.Run("a\\b", func(t *testing.T) {}) // want `subtest name "a\\\\b" contains characters that require quoting` + + // Bad: double quotes + t.Run("say \"hello\"", func(t *testing.T) {}) // want `subtest name "say \\"hello\\"" contains characters that require quoting` + + // Bad: hash + t.Run("comment#1", func(t *testing.T) {}) // want `subtest name "comment#1" contains characters that require quoting` + + // Bad: leading/trailing dash + t.Run("-leading-dash", func(t *testing.T) {}) // want `subtest name "-leading-dash" starts or ends with '-' which is problematic` + t.Run("trailing-dash-", func(t *testing.T) {}) // want `subtest name "trailing-dash-" starts or ends with '-' which is problematic` + t.Run("-both-", func(t *testing.T) {}) // want `subtest name "-both-" starts or ends with '-' which is problematic` + + // Good: clean names + t.Run("zero-passes", func(t *testing.T) {}) + t.Run("simple_test", func(t *testing.T) {}) + t.Run("CamelCase", func(t *testing.T) {}) + t.Run("with-dashes", func(t *testing.T) {}) + t.Run("123", func(t *testing.T) {}) + t.Run("comma,separated", func(t *testing.T) {}) + t.Run("colon:value", func(t *testing.T) {}) + t.Run("slash/path", func(t *testing.T) {}) + t.Run("equals=sign", func(t *testing.T) {}) +} + +func TestTableDriven(t *testing.T) { + tests := []struct { + name string + val int + }{ + {name: "bad space name", val: 1}, // want `subtest name "bad space name" contains characters that require quoting` + {name: "good-name", val: 2}, + {name: "also(bad)", val: 3}, // want `subtest name "also\(bad\)" contains characters that require quoting` + {name: "it's-bad", val: 4}, // want `subtest name "it's-bad" contains characters that require quoting` + {name: "clean-name", val: 5}, + {name: "-leading-dash", val: 6}, // want `subtest name "-leading-dash" starts or ends with '-' which is problematic` + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) {}) + } +} + +func TestTableDrivenVar(t *testing.T) { + var tests = []struct { + name string + val int + }{ + {name: "has spaces", val: 1}, // want `subtest name "has spaces" contains characters that require quoting` + {name: "ok-name", val: 2}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) {}) + } +} + +func TestTableDrivenMap(t *testing.T) { + tests := map[string]struct { + name string + val int + }{ + "key1": {name: "bad name here", val: 1}, // want `subtest name "bad name here" contains characters that require quoting` + "key2": {name: "ok-name", val: 2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) {}) + } +} + +func TestNotTesting(t *testing.T) { + // Not a t.Run call, should not trigger. + s := struct{ Run func(string, func()) }{} + s.Run("bad name here", func() {}) +} + +func TestDynamicName(t *testing.T) { + // Dynamic name, not a string literal — should not trigger. + name := getName() + t.Run(name, func(t *testing.T) {}) +} + +func getName() string { return "foo" } + +func BenchmarkDirect(b *testing.B) { + // Also check b.Run. + b.Run("bad name here", func(b *testing.B) {}) // want `subtest name "bad name here" contains characters that require quoting` + b.Run("good-name", func(b *testing.B) {}) +} diff --git a/cmd/vet/vet.go b/cmd/vet/vet.go index babc30d254719..38ffdb6d02891 100644 --- a/cmd/vet/vet.go +++ b/cmd/vet/vet.go @@ -9,6 +9,8 @@ import ( "golang.org/x/tools/go/analysis/unitchecker" "tailscale.com/cmd/vet/jsontags" + "tailscale.com/cmd/vet/lowerell" + "tailscale.com/cmd/vet/subtestnames" ) //go:embed jsontags_allowlist @@ -20,5 +22,5 @@ func init() { } func main() { - unitchecker.Main(jsontags.Analyzer) + unitchecker.Main(jsontags.Analyzer, lowerell.Analyzer, subtestnames.Analyzer) } diff --git a/cmd/viewer/tests/tests.go b/cmd/viewer/tests/tests.go index cbffd38845ec3..060ac9d8e96fa 100644 --- a/cmd/viewer/tests/tests.go +++ b/cmd/viewer/tests/tests.go @@ -9,11 +9,10 @@ import ( "net/netip" "golang.org/x/exp/constraints" - "tailscale.com/types/ptr" "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct,StructWithMapOfViews --clone-only-type=OnlyGetClone +//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct,StructWithMapOfViews,StructWithNamedMap,StructWithNamedSlice --clone-only-type=OnlyGetClone type StructWithoutPtrs struct { Int int @@ -135,7 +134,7 @@ func (c *Container[T]) Clone() *Container[T] { return &Container[T]{cloner.Clone()} } if !views.ContainsPointers[T]() { - return ptr.To(*c) + return new(*c) } panic(fmt.Errorf("%T contains pointers, but is not cloneable", c.Item)) } @@ -242,3 +241,61 @@ type GenericTypeAliasStruct[T integer, T2 views.ViewCloner[T2, V2], V2 views.Str type StructWithMapOfViews struct { MapOfViews map[string]StructWithoutPtrsView } + +// NamedMap is a named map type with its own Clone and View methods. +// This tests that the viewer calls View() on named map types rather +// than trying to generate a view of the underlying map[string]any. +type NamedMap map[string]any + +func (m NamedMap) Clone() NamedMap { + if m == nil { + return nil + } + m2 := make(NamedMap, len(m)) + for k, v := range m { + m2[k] = v + } + return m2 +} + +// NamedMapView is a read-only view of NamedMap. +type NamedMapView struct { + Đļ NamedMap +} + +func (m NamedMap) View() NamedMapView { return NamedMapView{m} } + +func (v NamedMapView) Get(k string) (any, bool) { val, ok := v.Đļ[k]; return val, ok } +func (v NamedMapView) Len() int { return len(v.Đļ) } + +type StructWithNamedMap struct { + Attrs NamedMap +} + +// NamedSlice is a named slice type with its own Clone and View methods. +// This tests that the viewer calls View() on named slice types rather +// than trying to generate a view of the underlying []any. +type NamedSlice []any + +func (s NamedSlice) Clone() NamedSlice { + if s == nil { + return nil + } + s2 := make(NamedSlice, len(s)) + copy(s2, s) + return s2 +} + +// NamedSliceView is a read-only view of NamedSlice. +type NamedSliceView struct { + Đļ NamedSlice +} + +func (s NamedSlice) View() NamedSliceView { return NamedSliceView{s} } + +func (v NamedSliceView) At(i int) any { return v.Đļ[i] } +func (v NamedSliceView) Len() int { return len(v.Đļ) } + +type StructWithNamedSlice struct { + Items NamedSlice +} diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index cbf5ec2653d98..bc576ec975489 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -10,7 +10,6 @@ import ( "net/netip" "golang.org/x/exp/constraints" - "tailscale.com/types/ptr" "tailscale.com/types/views" ) @@ -23,13 +22,13 @@ func (src *StructWithPtrs) Clone() *StructWithPtrs { dst := new(StructWithPtrs) *dst = *src if dst.Value != nil { - dst.Value = ptr.To(*src.Value) + dst.Value = new(*src.Value) } if dst.Int != nil { - dst.Int = ptr.To(*src.Int) + dst.Int = new(*src.Int) } if dst.NoView != nil { - dst.NoView = ptr.To(*src.NoView) + dst.NoView = new(*src.NoView) } return dst } @@ -90,21 +89,43 @@ func (src *Map) Clone() *Map { if v == nil { dst.StructPtrWithoutPtr[k] = nil } else { - dst.StructPtrWithoutPtr[k] = ptr.To(*v) + dst.StructPtrWithoutPtr[k] = new(*v) } } } dst.StructWithoutPtr = maps.Clone(src.StructWithoutPtr) if dst.SlicesWithPtrs != nil { dst.SlicesWithPtrs = map[string][]*StructWithPtrs{} - for k := range src.SlicesWithPtrs { - dst.SlicesWithPtrs[k] = append([]*StructWithPtrs{}, src.SlicesWithPtrs[k]...) + for k, sv := range src.SlicesWithPtrs { + if sv == nil { + dst.SlicesWithPtrs[k] = nil + continue + } + dst.SlicesWithPtrs[k] = make([]*StructWithPtrs, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.SlicesWithPtrs[k][i] = nil + } else { + dst.SlicesWithPtrs[k][i] = sv[i].Clone() + } + } } } if dst.SlicesWithoutPtrs != nil { dst.SlicesWithoutPtrs = map[string][]*StructWithoutPtrs{} - for k := range src.SlicesWithoutPtrs { - dst.SlicesWithoutPtrs[k] = append([]*StructWithoutPtrs{}, src.SlicesWithoutPtrs[k]...) + for k, sv := range src.SlicesWithoutPtrs { + if sv == nil { + dst.SlicesWithoutPtrs[k] = nil + continue + } + dst.SlicesWithoutPtrs[k] = make([]*StructWithoutPtrs, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.SlicesWithoutPtrs[k][i] = nil + } else { + dst.SlicesWithoutPtrs[k][i] = new(*sv[i]) + } + } } } dst.StructWithoutPtrKey = maps.Clone(src.StructWithoutPtrKey) @@ -116,8 +137,19 @@ func (src *Map) Clone() *Map { } if dst.SliceIntPtr != nil { dst.SliceIntPtr = map[string][]*int{} - for k := range src.SliceIntPtr { - dst.SliceIntPtr[k] = append([]*int{}, src.SliceIntPtr[k]...) + for k, sv := range src.SliceIntPtr { + if sv == nil { + dst.SliceIntPtr[k] = nil + continue + } + dst.SliceIntPtr[k] = make([]*int, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.SliceIntPtr[k][i] = nil + } else { + dst.SliceIntPtr[k][i] = new(*sv[i]) + } + } } } dst.PointerKey = maps.Clone(src.PointerKey) @@ -156,7 +188,7 @@ func (src *StructWithSlices) Clone() *StructWithSlices { if src.ValuePointers[i] == nil { dst.ValuePointers[i] = nil } else { - dst.ValuePointers[i] = ptr.To(*src.ValuePointers[i]) + dst.ValuePointers[i] = new(*src.ValuePointers[i]) } } } @@ -185,7 +217,7 @@ func (src *StructWithSlices) Clone() *StructWithSlices { if src.Ints[i] == nil { dst.Ints[i] = nil } else { - dst.Ints[i] = ptr.To(*src.Ints[i]) + dst.Ints[i] = new(*src.Ints[i]) } } } @@ -248,7 +280,7 @@ func (src *GenericIntStruct[T]) Clone() *GenericIntStruct[T] { dst := new(GenericIntStruct[T]) *dst = *src if dst.Pointer != nil { - dst.Pointer = ptr.To(*src.Pointer) + dst.Pointer = new(*src.Pointer) } dst.Slice = append(src.Slice[:0:0], src.Slice...) dst.Map = maps.Clone(src.Map) @@ -258,7 +290,7 @@ func (src *GenericIntStruct[T]) Clone() *GenericIntStruct[T] { if src.PtrSlice[i] == nil { dst.PtrSlice[i] = nil } else { - dst.PtrSlice[i] = ptr.To(*src.PtrSlice[i]) + dst.PtrSlice[i] = new(*src.PtrSlice[i]) } } } @@ -269,7 +301,7 @@ func (src *GenericIntStruct[T]) Clone() *GenericIntStruct[T] { if v == nil { dst.PtrValueMap[k] = nil } else { - dst.PtrValueMap[k] = ptr.To(*v) + dst.PtrValueMap[k] = new(*v) } } } @@ -305,7 +337,7 @@ func (src *GenericNoPtrsStruct[T]) Clone() *GenericNoPtrsStruct[T] { dst := new(GenericNoPtrsStruct[T]) *dst = *src if dst.Pointer != nil { - dst.Pointer = ptr.To(*src.Pointer) + dst.Pointer = new(*src.Pointer) } dst.Slice = append(src.Slice[:0:0], src.Slice...) dst.Map = maps.Clone(src.Map) @@ -315,7 +347,7 @@ func (src *GenericNoPtrsStruct[T]) Clone() *GenericNoPtrsStruct[T] { if src.PtrSlice[i] == nil { dst.PtrSlice[i] = nil } else { - dst.PtrSlice[i] = ptr.To(*src.PtrSlice[i]) + dst.PtrSlice[i] = new(*src.PtrSlice[i]) } } } @@ -326,7 +358,7 @@ func (src *GenericNoPtrsStruct[T]) Clone() *GenericNoPtrsStruct[T] { if v == nil { dst.PtrValueMap[k] = nil } else { - dst.PtrValueMap[k] = ptr.To(*v) + dst.PtrValueMap[k] = new(*v) } } } @@ -375,7 +407,7 @@ func (src *GenericCloneableStruct[T, V]) Clone() *GenericCloneableStruct[T, V] { } } if dst.Pointer != nil { - dst.Pointer = ptr.To((*src.Pointer).Clone()) + dst.Pointer = new((*src.Pointer).Clone()) } if src.PtrSlice != nil { dst.PtrSlice = make([]*T, len(src.PtrSlice)) @@ -383,7 +415,7 @@ func (src *GenericCloneableStruct[T, V]) Clone() *GenericCloneableStruct[T, V] { if src.PtrSlice[i] == nil { dst.PtrSlice[i] = nil } else { - dst.PtrSlice[i] = ptr.To((*src.PtrSlice[i]).Clone()) + dst.PtrSlice[i] = new((*src.PtrSlice[i]).Clone()) } } } @@ -394,14 +426,21 @@ func (src *GenericCloneableStruct[T, V]) Clone() *GenericCloneableStruct[T, V] { if v == nil { dst.PtrValueMap[k] = nil } else { - dst.PtrValueMap[k] = ptr.To((*v).Clone()) + dst.PtrValueMap[k] = new((*v).Clone()) } } } if dst.SliceMap != nil { dst.SliceMap = map[string][]T{} - for k := range src.SliceMap { - dst.SliceMap[k] = append([]T{}, src.SliceMap[k]...) + for k, sv := range src.SliceMap { + if sv == nil { + dst.SliceMap[k] = nil + continue + } + dst.SliceMap[k] = make([]T, len(sv)) + for i := range sv { + dst.SliceMap[k][i] = sv[i].Clone() + } } } return dst @@ -457,7 +496,7 @@ func (src *StructWithTypeAliasFields) Clone() *StructWithTypeAliasFields { dst.WithPtr = *src.WithPtr.Clone() dst.WithPtrByPtr = src.WithPtrByPtr.Clone() if dst.WithoutPtrByPtr != nil { - dst.WithoutPtrByPtr = ptr.To(*src.WithoutPtrByPtr) + dst.WithoutPtrByPtr = new(*src.WithoutPtrByPtr) } if src.SliceWithPtrs != nil { dst.SliceWithPtrs = make([]*StructWithPtrsAlias, len(src.SliceWithPtrs)) @@ -475,7 +514,7 @@ func (src *StructWithTypeAliasFields) Clone() *StructWithTypeAliasFields { if src.SliceWithoutPtrs[i] == nil { dst.SliceWithoutPtrs[i] = nil } else { - dst.SliceWithoutPtrs[i] = ptr.To(*src.SliceWithoutPtrs[i]) + dst.SliceWithoutPtrs[i] = new(*src.SliceWithoutPtrs[i]) } } } @@ -495,20 +534,42 @@ func (src *StructWithTypeAliasFields) Clone() *StructWithTypeAliasFields { if v == nil { dst.MapWithoutPtrs[k] = nil } else { - dst.MapWithoutPtrs[k] = ptr.To(*v) + dst.MapWithoutPtrs[k] = new(*v) } } } if dst.MapOfSlicesWithPtrs != nil { dst.MapOfSlicesWithPtrs = map[string][]*StructWithPtrsAlias{} - for k := range src.MapOfSlicesWithPtrs { - dst.MapOfSlicesWithPtrs[k] = append([]*StructWithPtrsAlias{}, src.MapOfSlicesWithPtrs[k]...) + for k, sv := range src.MapOfSlicesWithPtrs { + if sv == nil { + dst.MapOfSlicesWithPtrs[k] = nil + continue + } + dst.MapOfSlicesWithPtrs[k] = make([]*StructWithPtrsAlias, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.MapOfSlicesWithPtrs[k][i] = nil + } else { + dst.MapOfSlicesWithPtrs[k][i] = sv[i].Clone() + } + } } } if dst.MapOfSlicesWithoutPtrs != nil { dst.MapOfSlicesWithoutPtrs = map[string][]*StructWithoutPtrsAlias{} - for k := range src.MapOfSlicesWithoutPtrs { - dst.MapOfSlicesWithoutPtrs[k] = append([]*StructWithoutPtrsAlias{}, src.MapOfSlicesWithoutPtrs[k]...) + for k, sv := range src.MapOfSlicesWithoutPtrs { + if sv == nil { + dst.MapOfSlicesWithoutPtrs[k] = nil + continue + } + dst.MapOfSlicesWithoutPtrs[k] = make([]*StructWithoutPtrsAlias, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.MapOfSlicesWithoutPtrs[k][i] = nil + } else { + dst.MapOfSlicesWithoutPtrs[k][i] = new(*sv[i]) + } + } } } return dst @@ -564,3 +625,37 @@ func (src *StructWithMapOfViews) Clone() *StructWithMapOfViews { var _StructWithMapOfViewsCloneNeedsRegeneration = StructWithMapOfViews(struct { MapOfViews map[string]StructWithoutPtrsView }{}) + +// Clone makes a deep copy of StructWithNamedMap. +// The result aliases no memory with the original. +func (src *StructWithNamedMap) Clone() *StructWithNamedMap { + if src == nil { + return nil + } + dst := new(StructWithNamedMap) + *dst = *src + dst.Attrs = src.Attrs.Clone() + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _StructWithNamedMapCloneNeedsRegeneration = StructWithNamedMap(struct { + Attrs NamedMap +}{}) + +// Clone makes a deep copy of StructWithNamedSlice. +// The result aliases no memory with the original. +func (src *StructWithNamedSlice) Clone() *StructWithNamedSlice { + if src == nil { + return nil + } + dst := new(StructWithNamedSlice) + *dst = *src + dst.Items = src.Items.Clone() + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _StructWithNamedSliceCloneNeedsRegeneration = StructWithNamedSlice(struct { + Items NamedSlice +}{}) diff --git a/cmd/viewer/tests/tests_view.go b/cmd/viewer/tests/tests_view.go index fe073446ea200..29be2d78bb232 100644 --- a/cmd/viewer/tests/tests_view.go +++ b/cmd/viewer/tests/tests_view.go @@ -16,7 +16,7 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct,StructWithMapOfViews +//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct,StructWithMapOfViews,StructWithNamedMap,StructWithNamedSlice // View returns a read-only view of StructWithPtrs. func (p *StructWithPtrs) View() StructWithPtrsView { @@ -1129,3 +1129,151 @@ func (v StructWithMapOfViewsView) MapOfViews() views.Map[string, StructWithoutPt var _StructWithMapOfViewsViewNeedsRegeneration = StructWithMapOfViews(struct { MapOfViews map[string]StructWithoutPtrsView }{}) + +// View returns a read-only view of StructWithNamedMap. +func (p *StructWithNamedMap) View() StructWithNamedMapView { + return StructWithNamedMapView{Đļ: p} +} + +// StructWithNamedMapView provides a read-only view over StructWithNamedMap. +// +// Its methods should only be called if `Valid()` returns true. +type StructWithNamedMapView struct { + // Đļ is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *StructWithNamedMap +} + +// Valid reports whether v's underlying value is non-nil. +func (v StructWithNamedMapView) Valid() bool { return v.Đļ != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v StructWithNamedMapView) AsStruct() *StructWithNamedMap { + if v.Đļ == nil { + return nil + } + return v.Đļ.Clone() +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (v StructWithNamedMapView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v StructWithNamedMapView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +func (v *StructWithNamedMapView) UnmarshalJSON(b []byte) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x StructWithNamedMap + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *StructWithNamedMapView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x StructWithNamedMap + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +func (v StructWithNamedMapView) Attrs() NamedMapView { return v.Đļ.Attrs.View() } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _StructWithNamedMapViewNeedsRegeneration = StructWithNamedMap(struct { + Attrs NamedMap +}{}) + +// View returns a read-only view of StructWithNamedSlice. +func (p *StructWithNamedSlice) View() StructWithNamedSliceView { + return StructWithNamedSliceView{Đļ: p} +} + +// StructWithNamedSliceView provides a read-only view over StructWithNamedSlice. +// +// Its methods should only be called if `Valid()` returns true. +type StructWithNamedSliceView struct { + // Đļ is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *StructWithNamedSlice +} + +// Valid reports whether v's underlying value is non-nil. +func (v StructWithNamedSliceView) Valid() bool { return v.Đļ != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v StructWithNamedSliceView) AsStruct() *StructWithNamedSlice { + if v.Đļ == nil { + return nil + } + return v.Đļ.Clone() +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (v StructWithNamedSliceView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v StructWithNamedSliceView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +func (v *StructWithNamedSliceView) UnmarshalJSON(b []byte) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x StructWithNamedSlice + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *StructWithNamedSliceView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x StructWithNamedSlice + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +func (v StructWithNamedSliceView) Items() NamedSliceView { return v.Đļ.Items.View() } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _StructWithNamedSliceViewNeedsRegeneration = StructWithNamedSlice(struct { + Items NamedSlice +}{}) diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index 56b999f5f50fe..21d417878df56 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -282,6 +282,22 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, fie writeTemplateWithComment("valueField", fname) continue } + // Named map/slice types whose element type is opaque (e.g. any) + // can't be safely wrapped in views.Map/views.Slice because the + // accessor would leak the raw element. If the type provides its + // own View() method the author can return a purpose-built safe + // view; use it. Otherwise fall through to the normal handling, + // which will reject the type as unsupported. + if named, _ := codegen.NamedTypeOf(fieldType); named != nil { + switch fieldType.Underlying().(type) { + case *types.Map, *types.Slice: + if viewType := viewTypeForValueType(fieldType); viewType != nil { + args.FieldViewName = it.QualifiedName(viewType) + writeTemplateWithComment("viewField", fname) + continue + } + } + } switch underlying := fieldType.Underlying().(type) { case *types.Slice: slice := underlying @@ -500,8 +516,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, fie } writeTemplateWithComment("unsupportedField", fname) } - for i := range typ.NumMethods() { - f := typ.Method(i) + for f := range typ.Methods() { if !f.Exported() { continue } @@ -720,7 +735,7 @@ func main() { fieldComments := getFieldComments(pkg.Syntax) cloneOnlyType := map[string]bool{} - for _, t := range strings.Split(*flagCloneOnlyTypes, ",") { + for t := range strings.SplitSeq(*flagCloneOnlyTypes, ",") { cloneOnlyType[t] = true } diff --git a/cmd/viewer/viewer_test.go b/cmd/viewer/viewer_test.go index 8bd18d4806ae2..e53f1d3c56a4a 100644 --- a/cmd/viewer/viewer_test.go +++ b/cmd/viewer/viewer_test.go @@ -10,11 +10,109 @@ import ( "go/parser" "go/token" "go/types" + "strings" "testing" "tailscale.com/util/codegen" ) +// TestNamedMapWithView tests that a named map type with a user-supplied +// View() method causes the generated view accessor to call .View() and +// return the user-defined view type. Without the View() method the +// generator should reject the field as unsupported. +func TestNamedMapWithView(t *testing.T) { + const src = ` +package test + +// AttrMap is a named map whose values are opaque (any). +// It provides its own Clone and View methods. +type AttrMap map[string]any + +func (m AttrMap) Clone() AttrMap { + m2 := make(AttrMap, len(m)) + for k, v := range m { m2[k] = v } + return m2 +} + +// AttrMapView is a hand-written read-only view of AttrMap. +type AttrMapView struct{ m AttrMap } + +func (m AttrMap) View() AttrMapView { return AttrMapView{m} } + +// Container holds an AttrMap field. +type Container struct { + Attrs AttrMap +} +` + output := genViewOutput(t, src, "Container") + + // The generated accessor must call .View() and return the + // user-defined AttrMapView, not views.Map or the raw AttrMap. + const want = "func (v ContainerView) Attrs() AttrMapView { return v.Đļ.Attrs.View() }" + if !strings.Contains(output, want) { + t.Errorf("generated output missing expected accessor\nwant: %s\ngot:\n%s", want, output) + } +} + +// TestNamedMapWithoutView tests that a named map[string]any WITHOUT a +// View() method does NOT generate an accessor that calls .View(). +func TestNamedMapWithoutView(t *testing.T) { + const src = ` +package test + +type AttrMap map[string]any + +func (m AttrMap) Clone() AttrMap { + m2 := make(AttrMap, len(m)) + for k, v := range m { m2[k] = v } + return m2 +} + +type Container struct { + Attrs AttrMap +} +` + output := genViewOutput(t, src, "Container") + + // Must not generate an accessor that calls .Attrs.View(), + // since AttrMap doesn't have a View() method. + if strings.Contains(output, "Attrs.View()") { + t.Errorf("generated code calls .Attrs.View() but AttrMap has no View method:\n%s", output) + } + // Must not return AttrMapView (which doesn't exist). + if strings.Contains(output, "AttrMapView") { + t.Errorf("generated code references AttrMapView but it doesn't exist:\n%s", output) + } +} + +// genViewOutput parses src, runs genView on the named type, and returns +// the generated Go source. +func genViewOutput(t *testing.T, src string, typeName string) string { + t.Helper() + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", src, 0) + if err != nil { + t.Fatal(err) + } + conf := types.Config{} + pkg, err := conf.Check("test", fset, []*ast.File{f}, nil) + if err != nil { + t.Fatal(err) + } + obj := pkg.Scope().Lookup(typeName) + if obj == nil { + t.Fatalf("type %q not found", typeName) + } + named, ok := obj.(*types.TypeName).Type().(*types.Named) + if !ok { + t.Fatalf("%q is not a named type", typeName) + } + var buf bytes.Buffer + tracker := codegen.NewImportTracker(pkg) + genView(&buf, tracker, named, nil) + return buf.String() +} + func TestViewerImports(t *testing.T) { tests := []struct { name string diff --git a/control/controlbase/conn_test.go b/control/controlbase/conn_test.go index a1e2b313de5b6..202d39efae9e7 100644 --- a/control/controlbase/conn_test.go +++ b/control/controlbase/conn_test.go @@ -11,12 +11,11 @@ import ( "fmt" "io" "net" - "runtime" "strings" "sync" "testing" "testing/iotest" - "time" + "testing/synctest" chp "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/nettest" @@ -226,79 +225,31 @@ func TestConnStd(t *testing.T) { }) } -// tests that the idle memory overhead of a Conn blocked in a read is -// reasonable (under 2K). It was previously over 8KB with two 4KB -// buffers for rx/tx. This make sure we don't regress. Hopefully it -// doesn't turn into a flaky test. If so, const max can be adjusted, -// or it can be deleted or reworked. +// tests that the memory overhead of a Conn blocked in a read is +// reasonable. It was previously over 8KB with two 4KB buffers for +// rx/tx. This makes sure we don't regress. func TestConnMemoryOverhead(t *testing.T) { - num := 1000 - if testing.Short() { - num = 100 - } - ng0 := runtime.NumGoroutine() - - runtime.GC() - var ms0 runtime.MemStats - runtime.ReadMemStats(&ms0) - - var closers []io.Closer - closeAll := func() { - for _, c := range closers { - c.Close() - } - closers = nil - } - defer closeAll() - - for range num { - client, server := pair(t) - closers = append(closers, client, server) - go func() { - var buf [1]byte - client.Read(buf[:]) - }() - } - - t0 := time.Now() - deadline := t0.Add(3 * time.Second) - var ngo int - for time.Now().Before(deadline) { - runtime.GC() - ngo = runtime.NumGoroutine() - if ngo >= num { - break + synctest.Test(t, func(t *testing.T) { + // AllocsPerRun runs the function once for warmup (filling + // allocator slab caches, etc.) and then measures over the + // remaining runs, returning the average allocation count. + allocs := testing.AllocsPerRun(100, func() { + client, server := pair(t) + go func() { + var buf [1]byte + client.Read(buf[:]) + }() + synctest.Wait() + client.Close() + server.Close() + synctest.Wait() + }) + t.Logf("allocs per blocked-conn pair: %v", allocs) + const max = 400 + if allocs > max { + t.Errorf("allocs per blocked-conn pair = %v, want <= %v", allocs, max) } - time.Sleep(10 * time.Millisecond) - } - if ngo < num { - t.Fatalf("only %v goroutines; expected %v+", ngo, num) - } - runtime.GC() - var ms runtime.MemStats - runtime.ReadMemStats(&ms) - growthTotal := int64(ms.HeapAlloc) - int64(ms0.HeapAlloc) - growthEach := float64(growthTotal) / float64(num) - t.Logf("Alloced %v bytes, %.2f B/each", growthTotal, growthEach) - const max = 2048 - if growthEach > max { - t.Errorf("allocated more than expected; want max %v bytes/each", max) - } - - closeAll() - - // And make sure our goroutines go away too. - deadline = time.Now().Add(3 * time.Second) - for time.Now().Before(deadline) { - ngo = runtime.NumGoroutine() - if ngo < ng0+num/10 { - break - } - time.Sleep(10 * time.Millisecond) - } - if ngo >= ng0+num/10 { - t.Errorf("goroutines didn't go back down; started at %v, now %v", ng0, ngo) - } + }) } type readSink struct { diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index 783ca36c4f45d..05c7552c82185 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -91,7 +91,7 @@ func (c *Auto) updateRoutine() { bo.BackOff(ctx, err) continue } - bo.BackOff(ctx, nil) + bo.Reset() c.direct.logf("[v1] successful lite map update in %v", d) lastUpdateGenInformed = gen @@ -356,7 +356,15 @@ func (c *Auto) authRoutine() { if err != nil { c.direct.health.SetAuthRoutineInError(err) report(err, f) - bo.BackOff(ctx, err) + if rle, ok := errors.AsType[*rateLimitError](err); ok { + c.logf("authRoutine: %s", rle) + select { + case <-ctx.Done(): + case <-time.After(rle.retryAfter): + } + } else { + bo.BackOff(ctx, err) + } continue } if url != "" { @@ -382,7 +390,7 @@ func (c *Auto) authRoutine() { // backoff to avoid a busy loop. bo.BackOff(ctx, errors.New("login URL not changing")) } else { - bo.BackOff(ctx, nil) + bo.Reset() } continue } @@ -397,7 +405,7 @@ func (c *Auto) authRoutine() { c.sendStatus("authRoutine-success", nil, "", nil) c.restartMap() - bo.BackOff(ctx, nil) + bo.Reset() } } @@ -446,13 +454,14 @@ func (mrs mapRoutineState) UpdateFullNetmap(nm *netmap.NetworkMap) { c.expiry = nm.SelfKeyExpiry() stillAuthed := c.loggedIn c.logf("[v1] mapRoutine: netmap received: loggedIn=%v inMapPoll=true", stillAuthed) + + // Reset the backoff timer if we got a netmap. + mrs.bo.Reset() c.mu.Unlock() if stillAuthed { c.sendStatus("mapRoutine-got-netmap", nil, "", nm) } - // Reset the backoff timer if we got a netmap. - mrs.bo.Reset() } func (mrs mapRoutineState) UpdateNetmapDelta(muts []netmap.NodeMutation) bool { @@ -477,6 +486,27 @@ func (mrs mapRoutineState) UpdateNetmapDelta(muts []netmap.NodeMutation) bool { return err == nil && ok } +var _ patchDiscoKeyer = mapRoutineState{} + +func (mrs mapRoutineState) PatchDiscoKey(pub key.NodePublic, disco key.DiscoPublic) { + c := mrs.c + c.mu.Lock() + goodState := c.loggedIn && c.inMapPoll + dun, ok := c.observer.(patchDiscoKeyer) + c.mu.Unlock() + + if !goodState || !ok { + return + } + + ctx, cancel := context.WithTimeout(c.mapCtx, 2*time.Second) + defer cancel() + + c.observerQueue.RunSync(ctx, func() { + dun.PatchDiscoKey(pub, disco) + }) +} + // mapRoutine is responsible for keeping a read-only streaming connection to the // control server, and keeping the netmap up to date. func (c *Auto) mapRoutine() { @@ -526,13 +556,18 @@ func (c *Auto) mapRoutine() { c.mu.Lock() c.inMapPoll = false paused := c.paused + + if paused { + mrs.bo.Reset() + } else { + mrs.bo.BackOff(ctx, err) + } c.mu.Unlock() + // Now safe to call functions that might acquire the mutex if paused { - mrs.bo.BackOff(ctx, nil) c.logf("mapRoutine: paused") } else { - mrs.bo.BackOff(ctx, err) report(err, "PollNetMap") } } @@ -773,6 +808,15 @@ func (c *Auto) SetDiscoPublicKey(key key.DiscoPublic) { c.updateControl() } +// SetIPForwardingBroken updates the IP forwarding broken state and sends +// a control update if the value changed. +func (c *Auto) SetIPForwardingBroken(v bool) { + if !c.direct.SetIPForwardingBroken(v) { + return + } + c.updateControl() +} + func (c *Auto) Shutdown() { c.mu.Lock() if c.closed { diff --git a/control/controlclient/client.go b/control/controlclient/client.go index a57c6940a88c4..7d2eaa4fef763 100644 --- a/control/controlclient/client.go +++ b/control/controlclient/client.go @@ -87,6 +87,9 @@ type Client interface { // future map requests. This should be called after rotating the discovery key. // Note: the auto client uploads the new key to control immediately. SetDiscoPublicKey(key.DiscoPublic) + // SetIPForwardingBroken updates the IP forwarding broken state + // and sends a control update if the value changed. + SetIPForwardingBroken(bool) // ClientID returns the ClientID of a client. This ID is meant to // distinguish one client from another. ClientID() int64 diff --git a/control/controlclient/controlclient_test.go b/control/controlclient/controlclient_test.go index dca1d8ddf2f8b..5c25af0f4b433 100644 --- a/control/controlclient/controlclient_test.go +++ b/control/controlclient/controlclient_test.go @@ -38,8 +38,8 @@ import ( ) func fieldsOf(t reflect.Type) (fields []string) { - for i := range t.NumField() { - if name := t.Field(i).Name; name != "_" { + for field := range t.Fields() { + if name := field.Name; name != "_" { fields = append(fields, name) } } @@ -214,12 +214,12 @@ func TestRetryableErrors(t *testing.T) { } type retryableForTest interface { + error Retryable() bool } func isRetryableErrorForTest(err error) bool { - var ae retryableForTest - if errors.As(err, &ae) { + if ae, ok := errors.AsType[retryableForTest](err); ok { return ae.Retryable() } return false @@ -406,6 +406,118 @@ func testHTTPS(t *testing.T, withProxy bool) { } } +// TestRegisterRateLimited verifies that the client correctly handles 429 +// responses to registration requests by parsing the Retry-After header +// and returning a rateLimitError. +func TestRegisterRateLimited(t *testing.T) { + bakedroots.ResetForTest(t, tlstest.TestRootCA()) + + bus := eventbustest.NewBus(t) + + controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig()) + if err != nil { + t.Fatal(err) + } + defer controlLn.Close() + + var registerAttempts atomic.Int64 + tc := &testcontrol.Server{ + Logf: tstest.WhileTestRunningLogger(t), + MaybeRateLimitRegister: func() (bool, string, string) { + if registerAttempts.Add(1) == 1 { + return true, "30", "try again later" + } + return false, "", "" + }, + } + controlSrv := &http.Server{ + Handler: tc, + ErrorLog: logger.StdLogger(t.Logf), + } + go controlSrv.Serve(controlLn) + + const fakeControlIP = "1.2.3.4" + + dialer := &tsdial.Dialer{} + dialer.SetNetMon(netmon.NewStatic()) + dialer.SetBus(bus) + dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err) + } + var d net.Dialer + if host == fakeControlIP { + return d.DialContext(ctx, network, controlLn.Addr().String()) + } + return nil, fmt.Errorf("unexpected dial to %q", addr) + }) + + opts := Options{ + Persist: persist.Persist{}, + GetMachinePrivateKey: func() (key.MachinePrivate, error) { + return key.NewMachine(), nil + }, + ServerURL: "https://controlplane.tstest", + Clock: tstime.StdClock{}, + Hostinfo: &tailcfg.Hostinfo{ + BackendLogID: "test-backend-log-id", + }, + DiscoPublicKey: key.NewDisco().Public(), + Logf: t.Logf, + HealthTracker: health.NewTracker(bus), + PopBrowserURL: func(url string) { + t.Logf("PopBrowserURL: %q", url) + }, + Dialer: dialer, + Bus: bus, + } + d, err := NewDirect(opts) + if err != nil { + t.Fatalf("NewDirect: %v", err) + } + + d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) { + if host == "controlplane.tstest" { + return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil + } + t.Errorf("unexpected DNS query for %q", host) + return nil, fmt.Errorf("unexpected DNS lookup for %q", host) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // First attempt should get a 429 and return a rateLimitError. + _, err = d.TryLogin(ctx, LoginEphemeral) + if err == nil { + t.Fatal("expected rate limit error on first attempt, got nil") + } + var rle *rateLimitError + if !errors.As(err, &rle) { + t.Fatalf("expected *rateLimitError, got %T: %v", err, err) + } + if rle.retryAfter != 30*time.Second { + t.Errorf("retryAfter = %v, want 30s", rle.retryAfter) + } + if rle.msg != "try again later" { + t.Errorf("msg = %q, want %q", rle.msg, "try again later") + } + + // Second attempt should succeed (server no longer rate-limiting). + url, err := d.TryLogin(ctx, LoginEphemeral) + if err != nil { + t.Fatalf("TryLogin after rate limit: %v", err) + } + if url != "" { + t.Errorf("got URL %q, want empty", url) + } + + if got := registerAttempts.Load(); got != 2 { + t.Errorf("register attempts = %d, want 2", got) + } +} + func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.RequestURI != target { diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 6f3393b18dfdf..032999cb9c7f5 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -9,12 +9,15 @@ import ( "context" "crypto" "crypto/sha256" + "crypto/tls" + "crypto/x509" "encoding/binary" "encoding/json" "errors" "fmt" "io" "log" + "math/rand/v2" "net" "net/http" "net/netip" @@ -22,6 +25,7 @@ import ( "reflect" "runtime" "slices" + "strconv" "strings" "sync/atomic" "time" @@ -39,7 +43,6 @@ import ( "tailscale.com/net/dnscache" "tailscale.com/net/dnsfallback" "tailscale.com/net/netmon" - "tailscale.com/net/netutil" "tailscale.com/net/netx" "tailscale.com/net/tlsdial" "tailscale.com/net/tsdial" @@ -47,11 +50,11 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tka" "tailscale.com/tstime" + "tailscale.com/types/events" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/persist" - "tailscale.com/types/ptr" "tailscale.com/types/tkatype" "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" @@ -65,30 +68,30 @@ import ( // Direct is the client that connects to a tailcontrol server for a node. type Direct struct { - httpc *http.Client // HTTP client used to do TLS requests to control (just https://controlplane.tailscale.com/key?v=123) - interceptedDial *atomic.Bool // if non-nil, pointer to bool whether ScreenTime intercepted our dial - dialer *tsdial.Dialer - dnsCache *dnscache.Resolver - controlKnobs *controlknobs.Knobs // always non-nil - serverURL string // URL of the tailcontrol server - clock tstime.Clock - logf logger.Logf - netMon *netmon.Monitor // non-nil - health *health.Tracker - busClient *eventbus.Client - clientVersionPub *eventbus.Publisher[tailcfg.ClientVersion] - autoUpdatePub *eventbus.Publisher[AutoUpdate] - controlTimePub *eventbus.Publisher[ControlTime] - getMachinePrivKey func() (key.MachinePrivate, error) - debugFlags []string - skipIPForwardingCheck bool - pinger Pinger - popBrowser func(url string) // or nil - polc policyclient.Client // always non-nil - c2nHandler http.Handler // or nil - panicOnUse bool // if true, panic if client is used (for testing) - closedCtx context.Context // alive until Direct.Close is called - closeCtx context.CancelFunc // cancels closedCtx + httpc *http.Client // HTTP client used to do TLS requests to control (just https://controlplane.tailscale.com/key?v=123) + interceptedDial *atomic.Bool // if non-nil, pointer to bool whether ScreenTime intercepted our dial + dialer *tsdial.Dialer + dnsCache *dnscache.Resolver + controlKnobs *controlknobs.Knobs // always non-nil + serverURL string // URL of the tailcontrol server + clock tstime.Clock + logf logger.Logf + netMon *netmon.Monitor // non-nil + health *health.Tracker + extraRootCAs *x509.CertPool // additional trusted root CAs; or nil + busClient *eventbus.Client + clientVersionPub *eventbus.Publisher[tailcfg.ClientVersion] + autoUpdatePub *eventbus.Publisher[AutoUpdate] + controlTimePub *eventbus.Publisher[ControlTime] + getMachinePrivKey func() (key.MachinePrivate, error) + debugFlags []string + pinger Pinger + popBrowser func(url string) // or nil + polc policyclient.Client // always non-nil + c2nHandler http.Handler // or nil + panicOnUse bool // if true, panic if client is used (for testing) + closedCtx context.Context // alive until Direct.Close is called + closeCtx context.CancelFunc // cancels closedCtx dialPlan ControlDialPlanner // can be nil @@ -96,6 +99,7 @@ type Direct struct { serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now serverNoiseKey key.MachinePublic discoPubKey key.DiscoPublic // protected by mu; can be updated via [SetDiscoPublicKey] + ipForwardBroken bool // protected by mu; can be updated via [SetIPForwardingBroken] sfGroup singleflight.Group[struct{}, *ts2021.Client] // protects noiseClient creation. noiseClient *ts2021.Client // also protected by mu @@ -108,8 +112,9 @@ type Direct struct { netinfo *tailcfg.NetInfo endpoints []tailcfg.Endpoint tkaHead string - lastPingURL string // last PingRequest.URL received, for dup suppression - connectionHandleForTest string // sent in MapRequest.ConnectionHandleForTest + lastPingURL string // last PingRequest.URL received, for dup suppression + connectionHandleForTest string // sent in MapRequest.ConnectionHandleForTest + streamingMapSession *mapSession // the one streaming mapSession instance controlClientID int64 // Random ID used to differentiate clients for consumers of messages. } @@ -141,6 +146,7 @@ type Options struct { NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only) DebugFlags []string // debug settings to send to control HealthTracker *health.Tracker + ExtraRootCAs *x509.CertPool // additional trusted root CAs; or nil PopBrowserURL func(url string) // optional func to open browser Dialer *tsdial.Dialer // non-nil C2NHandler http.Handler // or nil @@ -160,11 +166,6 @@ type Options struct { // If nil, no status updates are reported. Observer Observer - // SkipIPForwardingCheck declares that the host's IP - // forwarding works and should not be double-checked by the - // controlclient package. - SkipIPForwardingCheck bool - // Pinger optionally specifies the Pinger to use to satisfy // MapResponse.PingRequest queries from the control plane. // If nil, PingRequest queries are not answered. @@ -233,6 +234,15 @@ type NetmapDeltaUpdater interface { UpdateNetmapDelta([]netmap.NodeMutation) (ok bool) } +// patchDiscoKeyer is an optional interface that can be implemented by an [Observer] to be +// notified about node disco keys received out-of-band from control, via +// existing connection state. +type patchDiscoKeyer interface { + // PatchDiscoKey reports to the receiver that the specified disco key + // for node was obtained out-of-band from control. + PatchDiscoKey(key.NodePublic, key.DiscoPublic) +} + var nextControlClientID atomic.Int64 // NewDirect returns a new Direct client. @@ -293,6 +303,12 @@ func NewDirect(opts Options) (*Direct, error) { f(tr) } } + if opts.ExtraRootCAs != nil { + if tr.TLSClientConfig == nil { + tr.TLSClientConfig = &tls.Config{} + } + tr.TLSClientConfig.RootCAs = opts.ExtraRootCAs + } tr.TLSClientConfig = tlsdial.Config(opts.HealthTracker, tr.TLSClientConfig) var dialFunc netx.DialFunc dialFunc, interceptedDial = makeScreenTimeDetectingDialFunc(opts.Dialer.SystemDial) @@ -308,26 +324,26 @@ func NewDirect(opts Options) (*Direct, error) { } c := &Direct{ - httpc: httpc, - interceptedDial: interceptedDial, - controlKnobs: opts.ControlKnobs, - getMachinePrivKey: opts.GetMachinePrivateKey, - serverURL: opts.ServerURL, - clock: opts.Clock, - logf: opts.Logf, - persist: opts.Persist.View(), - authKey: opts.AuthKey, - debugFlags: opts.DebugFlags, - netMon: netMon, - health: opts.HealthTracker, - skipIPForwardingCheck: opts.SkipIPForwardingCheck, - pinger: opts.Pinger, - polc: cmp.Or(opts.PolicyClient, policyclient.Client(policyclient.NoPolicyClient{})), - popBrowser: opts.PopBrowserURL, - c2nHandler: opts.C2NHandler, - dialer: opts.Dialer, - dnsCache: dnsCache, - dialPlan: opts.DialPlan, + httpc: httpc, + interceptedDial: interceptedDial, + controlKnobs: opts.ControlKnobs, + getMachinePrivKey: opts.GetMachinePrivateKey, + serverURL: opts.ServerURL, + clock: opts.Clock, + logf: opts.Logf, + persist: opts.Persist.View(), + authKey: opts.AuthKey, + debugFlags: opts.DebugFlags, + netMon: netMon, + health: opts.HealthTracker, + extraRootCAs: opts.ExtraRootCAs, + pinger: opts.Pinger, + polc: cmp.Or(opts.PolicyClient, policyclient.Client(policyclient.NoPolicyClient{})), + popBrowser: opts.PopBrowserURL, + c2nHandler: opts.C2NHandler, + dialer: opts.Dialer, + dnsCache: dnsCache, + dialPlan: opts.DialPlan, } c.discoPubKey = opts.DiscoPublicKey c.closedCtx, c.closeCtx = context.WithCancel(context.Background()) @@ -356,6 +372,38 @@ func NewDirect(opts Options) (*Direct, error) { c.clientVersionPub = eventbus.Publish[tailcfg.ClientVersion](c.busClient) c.autoUpdatePub = eventbus.Publish[AutoUpdate](c.busClient) c.controlTimePub = eventbus.Publish[ControlTime](c.busClient) + discoKeyPub := eventbus.Publish[events.PeerDiscoKeyUpdate](c.busClient) + eventbus.SubscribeFunc(c.busClient, func(update events.DiscoKeyAdvertisement) { + c.logf("controlclient direct: got TSMP disco key advertisement from %v via eventbus", update.Src) + var peerID tailcfg.NodeID + var peerKey key.NodePublic + var ok bool + c.mu.Lock() + sess := c.streamingMapSession + c.mu.Unlock() + if sess != nil { + peerID, peerKey, ok = sess.PeerIDAndKeyByTailscaleIP(update.Src) + } + + if sess != nil && ok { + c.logf("controlclient direct: updating discoKey for %v via mapSession", update.Src) + + // If we update without error, return. If the err indicates that the + // mapSession has gone away, we want to fall back to pushing the key + // further down the chain. + if err := sess.updateDiscoForNode( + peerID, peerKey, update.Key, time.Now(), false); err == nil || + !errors.Is(err, ErrChangeQueueClosed) { + return + } + } + + // We need to push the update further down the chain. Either because we do + // not have a mapSession (we are not connected to control) or because the + // mapSession queue has closed. + c.logf("controlclient direct: updating discoKey for %v via magicsock", update.Src) + discoKeyPub.Publish(events.PeerDiscoKeyUpdate(update)) + }) return c, nil } @@ -383,7 +431,7 @@ func (c *Direct) SetHostinfo(hi *tailcfg.Hostinfo) bool { if hi == nil { panic("nil Hostinfo") } - hi = ptr.To(*hi) + hi = new(*hi) hi.NetInfo = nil c.mu.Lock() defer c.mu.Unlock() @@ -529,6 +577,37 @@ var macOSScreenTime = health.Register(&health.Warnable{ ImpactsConnectivity: true, }) +type rateLimitError struct { + msg string + retryAfter time.Duration +} + +func (e *rateLimitError) Error() string { + return fmt.Sprintf("rate limited: %s (retry after %v)", e.msg, e.retryAfter) +} + +func parseRateLimitError(res *http.Response) *rateLimitError { + msg, _ := io.ReadAll(res.Body) + res.Body.Close() + + ret := &rateLimitError{ + msg: strings.TrimSpace(string(msg)), + } + + v := res.Header.Get("Retry-After") + if i, err := strconv.Atoi(v); err == nil { + ret.retryAfter = time.Duration(i) * time.Second + } else if t, err := http.ParseTime(v); err == nil { + ret.retryAfter = time.Until(t) + } + + // If the server didn't give us a valid Retry-After, default to 10s. + if ret.retryAfter <= 0 || ret.retryAfter > time.Hour { + ret.retryAfter = 5*time.Second + rand.N(5*time.Second) + } + return ret +} + func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, newURL string, nks tkatype.MarshaledSignature, err error) { if c.panicOnUse { panic("tainted client") @@ -723,6 +802,12 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new if err != nil { return regen, opt.URL, nil, fmt.Errorf("register request: %w", err) } + // Handle 429 Too Many Requests with a specific error type that includes the retry-after duration. + if res.StatusCode == 429 { + rle := parseRateLimitError(res) + msg := fmt.Sprintf("node registration rate limited; will retry after %v", rle.retryAfter) + return false, "", nil, vizerror.WrapWithMessage(rle, msg) + } if res.StatusCode != 200 { msg, _ := io.ReadAll(res.Body) res.Body.Close() @@ -829,21 +914,41 @@ func (c *Direct) PollNetMap(ctx context.Context, nu NetmapUpdater) error { return c.sendMapRequest(ctx, true, nu) } +// rememberLastNetmapUpdater is a container that remembers the last netmap +// update it observed. It is used by tests and [NetmapFromMapResponseForDebug]. +// It will report only the first netmap seen. type rememberLastNetmapUpdater struct { - last *netmap.NetworkMap + last *netmap.NetworkMap + lastTSMPKey key.NodePublic + lastTSMPDisco key.DiscoPublic + done chan any } func (nu *rememberLastNetmapUpdater) UpdateFullNetmap(nm *netmap.NetworkMap) { nu.last = nm + select { + case nu.done <- nil: + default: + } +} + +func (nu *rememberLastNetmapUpdater) PatchDiscoKey(key key.NodePublic, disco key.DiscoPublic) { + nu.lastTSMPKey = key + nu.lastTSMPDisco = disco } // FetchNetMapForTest fetches the netmap once. func (c *Direct) FetchNetMapForTest(ctx context.Context) (*netmap.NetworkMap, error) { var nu rememberLastNetmapUpdater + nu.done = make(chan any, 1) err := c.sendMapRequest(ctx, false, &nu) - if err == nil && nu.last == nil { + if err != nil { + return nil, err + } + if nu.last == nil { return nil, errors.New("[unexpected] sendMapRequest success without callback") } + <-nu.done return nu.last, err } @@ -862,6 +967,18 @@ func (c *Direct) SetDiscoPublicKey(key key.DiscoPublic) { c.discoPubKey = key } +// SetIPForwardingBroken updates the IP forwarding broken state. +// It reports whether the value changed. +func (c *Direct) SetIPForwardingBroken(v bool) bool { + c.mu.Lock() + defer c.mu.Unlock() + if c.ipForwardBroken == v { + return false + } + c.ipForwardBroken = v + return true +} + // ClientID returns the controlClientID of the controlClient. func (c *Direct) ClientID() int64 { return c.controlClientID @@ -992,10 +1109,6 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap } var extraDebugFlags []string - if buildfeatures.HasAdvertiseRoutes && hi != nil && c.netMon != nil && !c.skipIPForwardingCheck && - ipForwardingBroken(hi.RoutableIPs, c.netMon.InterfaceState()) { - extraDebugFlags = append(extraDebugFlags, "warn-ip-forwarding-off") - } if c.health.RouterHealth() != nil { extraDebugFlags = append(extraDebugFlags, "warn-router-unhealthy") } @@ -1080,8 +1193,22 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap return nil } + if isStreaming && c.streamingMapSession != nil { + panic("mapSession is already set") + } + sess := newMapSession(persist.PrivateNodeKey(), nu, c.controlKnobs) - defer sess.Close() + if isStreaming { + c.streamingMapSession = sess + defer func() { + sess.Close() + c.mu.Lock() + c.streamingMapSession = nil + c.mu.Unlock() + }() + } else { + defer sess.Close() + } sess.cancel = cancel sess.logf = c.logf sess.vlogf = vlogf @@ -1235,7 +1362,7 @@ func NetmapFromMapResponseForDebug(ctx context.Context, pr persist.PersistView, return nil, errors.New("PersistView invalid") } - nu := &rememberLastNetmapUpdater{} + nu := &rememberLastNetmapUpdater{done: make(chan any, 1)} sess := newMapSession(pr.PrivateNodeKey(), nu, nil) defer sess.Close() @@ -1243,6 +1370,7 @@ func NetmapFromMapResponseForDebug(ctx context.Context, pr persist.PersistView, return nil, fmt.Errorf("HandleNonKeepAliveMapResponse: %w", err) } + <-nu.done return sess.netmap(), nil } @@ -1303,10 +1431,10 @@ var jsonEscapedZero = []byte(`\u0000`) const justKeepAliveStr = `{"KeepAlive":true}` // decodeMsg is responsible for uncompressing msg and unmarshaling into v. -func (sess *mapSession) decodeMsg(compressedMsg []byte, v *tailcfg.MapResponse) error { +func (ms *mapSession) decodeMsg(compressedMsg []byte, v *tailcfg.MapResponse) error { // Fast path for common case of keep-alive message. // See tailscale/tailscale#17343. - if sess.keepAliveZ != nil && bytes.Equal(compressedMsg, sess.keepAliveZ) { + if ms.keepAliveZ != nil && bytes.Equal(compressedMsg, ms.keepAliveZ) { v.KeepAlive = true return nil } @@ -1315,7 +1443,7 @@ func (sess *mapSession) decodeMsg(compressedMsg []byte, v *tailcfg.MapResponse) if err != nil { return err } - sess.ztdDecodesForTest++ + ms.ztdDecodesForTest++ if DevKnob.DumpNetMaps() { var buf bytes.Buffer @@ -1330,7 +1458,7 @@ func (sess *mapSession) decodeMsg(compressedMsg []byte, v *tailcfg.MapResponse) return fmt.Errorf("response: %v", err) } if v.KeepAlive && string(b) == justKeepAliveStr { - sess.keepAliveZ = compressedMsg + ms.keepAliveZ = compressedMsg } return nil } @@ -1414,24 +1542,6 @@ func initDevKnob() devKnobs { var clock tstime.Clock = tstime.StdClock{} -// ipForwardingBroken reports whether the system's IP forwarding is disabled -// and will definitely not work for the routes provided. -// -// It should not return false positives. -// -// TODO(bradfitz): Change controlclient.Options.SkipIPForwardingCheck into a -// func([]netip.Prefix) error signature instead. -func ipForwardingBroken(routes []netip.Prefix, state *netmon.State) bool { - warn, err := netutil.CheckIPForwarding(routes, state) - if err != nil { - // Oh well, we tried. This is just for debugging. - // We don't want false positives. - // TODO: maybe we want a different warning for inability to check? - return false - } - return warn != nil -} - // isUniquePingRequest reports whether pr contains a new PingRequest.URL // not already handled, noting its value when returning true. func (c *Direct) isUniquePingRequest(pr *tailcfg.PingRequest) bool { @@ -1485,7 +1595,7 @@ func (c *Direct) answerPing(pr *tailcfg.PingRequest) { } return } - for _, t := range strings.Split(pr.Types, ",") { + for t := range strings.SplitSeq(pr.Types, ",") { switch pt := tailcfg.PingType(t); pt { case tailcfg.PingTSMP, tailcfg.PingDisco, tailcfg.PingICMP, tailcfg.PingPeerAPI: go doPingerPing(c.logf, httpc, pr, c.pinger, pt) @@ -1571,6 +1681,7 @@ func (c *Direct) getNoiseClient() (*ts2021.Client, error) { Logf: c.logf, NetMon: c.netMon, HealthTracker: c.health, + ExtraRootCAs: c.extraRootCAs, DialPlan: dp, }) if err != nil { diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go index d10b346ae39a7..98741482f1d02 100644 --- a/control/controlclient/direct_test.go +++ b/control/controlclient/direct_test.go @@ -5,9 +5,11 @@ package controlclient import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "net/netip" + "strings" "testing" "time" @@ -126,6 +128,109 @@ func fakeEndpoints(ports ...uint16) (ret []tailcfg.Endpoint) { return } +func TestParseRateLimitError(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + retryAfter string // Retry-After header value + wantMsg string + wantMin time.Duration // minimum expected retryAfter + wantMax time.Duration // maximum expected retryAfter + }{ + { + name: "retry-after-seconds", + statusCode: 429, + body: "too many requests", + retryAfter: "30", + wantMsg: "too many requests", + wantMin: 30 * time.Second, + wantMax: 30 * time.Second, + }, + { + name: "no-retry-after-header", + statusCode: 429, + body: "slow down", + retryAfter: "", + wantMsg: "slow down", + wantMin: 5 * time.Second, + wantMax: 10 * time.Second, + }, + { + name: "unparseable-retry-after", + statusCode: 429, + body: "rate limited", + retryAfter: "not-a-number", + wantMsg: "rate limited", + wantMin: 5 * time.Second, + wantMax: 10 * time.Second, + }, + { + name: "empty-body", + statusCode: 429, + body: "", + retryAfter: "5", + wantMsg: "", + wantMin: 5 * time.Second, + wantMax: 5 * time.Second, + }, + { + name: "body-with-whitespace", + statusCode: 429, + body: " too many requests \n", + retryAfter: "10", + wantMsg: "too many requests", + wantMin: 10 * time.Second, + wantMax: 10 * time.Second, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + if tt.retryAfter != "" { + rec.Header().Set("Retry-After", tt.retryAfter) + } + rec.WriteHeader(tt.statusCode) + rec.Body.WriteString(tt.body) + res := rec.Result() + + err := parseRateLimitError(res) + if err == nil { + t.Fatal("expected non-nil error") + } + + var rle *rateLimitError + if !errors.As(err, &rle) { + t.Fatalf("error is not a *rateLimitError: %T", err) + } + if rle.msg != tt.wantMsg { + t.Errorf("msg = %q, want %q", rle.msg, tt.wantMsg) + } + if rle.retryAfter < tt.wantMin || rle.retryAfter > tt.wantMax { + t.Errorf("retryAfter = %v, want between %v and %v", rle.retryAfter, tt.wantMin, tt.wantMax) + } + + // Verify the Error() string contains useful information. + errStr := err.Error() + if !strings.Contains(errStr, "rate limited") { + t.Errorf("Error() = %q, want it to contain 'rate limited'", errStr) + } + }) + } +} + +func TestRateLimitErrorIsError(t *testing.T) { + err := &rateLimitError{msg: "test", retryAfter: 5 * time.Second} + var target *rateLimitError + if !errors.As(err, &target) { + t.Fatal("errors.As should match *rateLimitError") + } + if target.retryAfter != 5*time.Second { + t.Errorf("retryAfter = %v, want 5s", target.retryAfter) + } +} + func TestTsmpPing(t *testing.T) { hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} diff --git a/control/controlclient/map.go b/control/controlclient/map.go index 18bd420ebaae3..9af3a75bd51b9 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -9,9 +9,11 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "io" "maps" "net" + "net/netip" "reflect" "runtime" "runtime/debug" @@ -28,7 +30,6 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" - "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/clientmetric" "tailscale.com/util/mak" @@ -37,6 +38,11 @@ import ( "tailscale.com/wgengine/filter" ) +type responseWithSource struct { + response *tailcfg.MapResponse + viaTSMP bool +} + // mapSession holds the state over a long-polled "map" request to the // control plane. // @@ -81,7 +87,6 @@ type mapSession struct { lastPrintMap time.Time lastNode tailcfg.NodeView lastCapSet set.Set[tailcfg.NodeCapability] - peers map[tailcfg.NodeID]tailcfg.NodeView lastDNSConfig *tailcfg.DNSConfig lastDERPMap *tailcfg.DERPMap lastUserProfile map[tailcfg.UserID]tailcfg.UserProfileView @@ -97,6 +102,14 @@ type mapSession struct { lastPopBrowserURL string lastTKAInfo *tailcfg.TKAInfo lastNetmapSummary string // from NetworkMap.VeryConcise + cqmu sync.Mutex + changeQueue chan responseWithSource + changeQueueClosed bool + processQueue sync.WaitGroup + + // mu protects the peers map. + peersMu sync.RWMutex + peers map[tailcfg.NodeID]tailcfg.NodeView } // newMapSession returns a mostly unconfigured new mapSession. @@ -119,11 +132,48 @@ func newMapSession(privateNodeKey key.NodePrivate, nu NetmapUpdater, controlKnob cancel: func() {}, onDebug: func(context.Context, *tailcfg.Debug) error { return nil }, onSelfNodeChanged: func(*netmap.NetworkMap) {}, + changeQueue: make(chan responseWithSource), + changeQueueClosed: false, } ms.sessionAliveCtx, ms.sessionAliveCtxClose = context.WithCancel(context.Background()) + ms.processQueue.Add(1) + go ms.run() return ms } +// run starts the mapSession processing a queue of tailcfg.MapResponse one by +// one until close() is called on the mapSession. +// When the mapSession is closed, the remaining queue is locked and processed +// before the mapSession is done processing. +func (ms *mapSession) run() { + defer ms.processQueue.Done() + + for { + select { + case change := <-ms.changeQueue: + ms.handleNonKeepAliveMapResponse(ms.sessionAliveCtx, change.response, change.viaTSMP) + case <-ms.sessionAliveCtx.Done(): + // Drain any remaining items in the queue before exiting. + // Lock the queue during this time to avoid updates through other channels + // to be overwritten. This is especially relevant for calls to + // updateDiscoForNode. + ms.cqmu.Lock() + ms.changeQueueClosed = true + ms.cqmu.Unlock() + for { + select { + case change := <-ms.changeQueue: + ms.handleNonKeepAliveMapResponse(ms.sessionAliveCtx, change.response, change.viaTSMP) + default: + // Queue is empty, close it and exit + close(ms.changeQueue) + return + } + } + } + } +} + // occasionallyPrintSummary logs summary at most once very 5 minutes. The // summary is the Netmap.VeryConcise result from the last received map response. func (ms *mapSession) occasionallyPrintSummary(summary string) { @@ -144,6 +194,34 @@ func (ms *mapSession) clock() tstime.Clock { func (ms *mapSession) Close() { ms.sessionAliveCtxClose() + ms.processQueue.Wait() +} + +var ErrChangeQueueClosed = errors.New("change queue closed") + +func (ms *mapSession) updateDiscoForNode(id tailcfg.NodeID, key key.NodePublic, discoKey key.DiscoPublic, lastSeen time.Time, online bool) error { + ms.cqmu.Lock() + + if ms.changeQueueClosed { + ms.cqmu.Unlock() + ms.processQueue.Wait() + return ErrChangeQueueClosed + } + defer ms.cqmu.Unlock() + + resp := responseWithSource{ + response: &tailcfg.MapResponse{ + PeersChangedPatch: []*tailcfg.PeerChange{{ + NodeID: id, + Key: &key, + LastSeen: &lastSeen, + Online: &online, + DiscoKey: &discoKey, + }}, + }, + viaTSMP: true, + } + return ms.addRespToQueue(resp) } // HandleNonKeepAliveMapResponse handles a non-KeepAlive MapResponse (full or @@ -152,8 +230,8 @@ func (ms *mapSession) Close() { // All fields that are valid on a KeepAlive MapResponse have already been // handled. // -// TODO(bradfitz): make this handle all fields later. For now (2023-08-20) this -// is [re]factoring progress enough. +// Debug messages are handled first, followed by pushing the response onto a +// queue for new updates handled sequentially. func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *tailcfg.MapResponse) error { if debug := resp.Debug; debug != nil { if err := ms.onDebug(ctx, debug); err != nil { @@ -161,6 +239,42 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t } } + ms.cqmu.Lock() + + if ms.changeQueueClosed { + ms.cqmu.Unlock() + ms.processQueue.Wait() + return ErrChangeQueueClosed + } + + defer ms.cqmu.Unlock() + + change := responseWithSource{ + response: resp, + viaTSMP: false, + } + + return ms.addRespToQueue(change) +} + +func (ms *mapSession) addRespToQueue(resp responseWithSource) error { + select { + case ms.changeQueue <- resp: + return nil + case <-ms.sessionAliveCtx.Done(): + return ErrChangeQueueClosed + } +} + +// handleNonKeepAliveMapResponse handles a non-KeepAlive MapResponse (full or +// incremental). +// +// All fields that are valid on a KeepAlive MapResponse have already been +// handled. +// +// TODO(bradfitz): make this handle all fields later. For now (2023-08-20) this +// is [re]factoring progress enough. +func (ms *mapSession) handleNonKeepAliveMapResponse(ctx context.Context, resp *tailcfg.MapResponse, viaTSMP bool) error { if DevKnob.StripEndpoints() { for _, p := range resp.Peers { p.Endpoints = nil @@ -200,8 +314,23 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t ms.patchifyPeersChanged(resp) + ms.removeUnwantedDiscoUpdates(resp, viaTSMP) + + // TSMP learned key was rejected, no need to do any more work in the engine. + if viaTSMP && len(resp.PeersChangedPatch) == 0 { + return nil + } + ms.removeUnwantedDiscoUpdatesFromFullNetmapUpdate(resp) + ms.updateStateFromResponse(resp) + // If source was learned via TSMP, the updated disco key need to be marked in + // userspaceEngine as an update that should not reconfigure the wireguard + // connection. + if viaTSMP { + ms.tryMarkDiscoAsLearnedFromTSMP(resp) + } + if ms.tryHandleIncrementally(resp) { ms.occasionallyPrintSummary(ms.lastNetmapSummary) return nil @@ -230,6 +359,21 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t return nil } +func (ms *mapSession) tryMarkDiscoAsLearnedFromTSMP(res *tailcfg.MapResponse) { + dun, ok := ms.netmapUpdater.(patchDiscoKeyer) + if !ok { + return + } + + // In reality we should never really have more than one change here over TSMP. + for _, change := range res.PeersChangedPatch { + if change == nil || change.DiscoKey == nil || change.Key == nil { + continue + } + dun.PatchDiscoKey(*change.Key, *change.DiscoKey) + } +} + // upgradeNode upgrades Node fields from the server into the modern forms // not using deprecated fields. func upgradeNode(n *tailcfg.Node) { @@ -282,6 +426,125 @@ type updateStats struct { changed int } +// removeUnwantedDiscoUpdates goes over the patchified updates and reject items +// where the node is offline and has last been seen before the recorded last seen. +func (ms *mapSession) removeUnwantedDiscoUpdates(resp *tailcfg.MapResponse, viaTSMP bool) { + ms.peersMu.RLock() + defer ms.peersMu.RUnlock() + + acceptedDiscoUpdates := resp.PeersChangedPatch[:0] + + for _, change := range resp.PeersChangedPatch { + // Accept if: + // - DiscoKey is nil and did not change. + // - Fields we rely on for rejection is missing. + if change.DiscoKey == nil || change.Online == nil || change.LastSeen == nil { + acceptedDiscoUpdates = append(acceptedDiscoUpdates, change) + continue + } + + existingNode, ok := ms.peers[change.NodeID] + // Accept if: + // - Cannot find the peer, don't have enough data. + if !ok { + acceptedDiscoUpdates = append(acceptedDiscoUpdates, change) + continue + } + + // Reject if: + // - key was learned via tsmp AND, + // - existing node is online AND, + // - key did not change. + // Here to avoid a deeper reconfig in the case where we get a TSMP key + // exchange while that node is already in a connected state (from the view + // of the control plane). This is meant to keep the node stable, avoiding a + // reconfiguration of the node deeper down in the engine. + // With this, we are avoiding updating the LastSeen and Online fields from + // TSMP updates when that is not relevant, overall making the connection + // state change less, and updating the engine less. + if viaTSMP && existingNode.Online().Get() && + *change.DiscoKey == existingNode.DiscoKey() { + continue + } + + // Accept if: + // - Node is online. + if *change.Online { + acceptedDiscoUpdates = append(acceptedDiscoUpdates, change) + continue + } + + // Accept if: + // - if we don't have a last seen to compare against on the existing node. + // - OR lastSeen moved forward in time. + if existingLastSeen, ok := existingNode.LastSeen().GetOk(); !ok || + change.LastSeen.After(existingLastSeen) { + acceptedDiscoUpdates = append(acceptedDiscoUpdates, change) + } + } + + resp.PeersChangedPatch = acceptedDiscoUpdates +} + +// removeUnwantedDiscoUpdatesFromFullNetmapUpdate makes a pass over the full +// set of peers in an update, usually only received when getting a full netmap +// from control at startup. If the pass finds a peer with a disco key where the +// local netmap has a newer key learned via TSMP, overwrite the update with the +// key from TSMP. +func (ms *mapSession) removeUnwantedDiscoUpdatesFromFullNetmapUpdate(resp *tailcfg.MapResponse) { + ms.peersMu.RLock() + defer ms.peersMu.RUnlock() + + if len(resp.Peers) == 0 { + return + } + for _, peer := range resp.Peers { + if peer.DiscoKey.IsZero() { + continue + } + + // Accept if: + // - peer is new + existingNode, ok := ms.peers[peer.ID] + if !ok { + continue + } + + // Accept if: + // - disco key has not changed + if existingNode.DiscoKey() == peer.DiscoKey { + continue + } + + // Accept if: + // - key has changed but peer is online + if peer.Online != nil && *peer.Online { + continue + } + + // Accept if: + // - there's no last seen on the existing node + existingLastSeen, ok := existingNode.LastSeen().GetOk() + if !ok { + continue + } + + // Accept if: + // - last seen on on control is higher + if peer.LastSeen != nil && peer.LastSeen.After(existingLastSeen) { + continue + } + + // Overwrite the key and last seen in the full netmap update. + peer.DiscoKey = existingNode.DiscoKey() + if t, ok := existingNode.LastSeen().GetOk(); ok { + peer.LastSeen = new(t) + } else { + peer.LastSeen = nil + } + } +} + // updateStateFromResponse updates ms from res. It takes ownership of res. func (ms *mapSession) updateStateFromResponse(resp *tailcfg.MapResponse) { ms.updatePeersStateFromResponse(resp) @@ -455,6 +718,9 @@ var ( // updatePeersStateFromResponseres updates ms.peers from resp. // It takes ownership of resp. func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (stats updateStats) { + ms.peersMu.Lock() + defer ms.peersMu.Unlock() + if ms.peers == nil { ms.peers = make(map[tailcfg.NodeID]tailcfg.NodeView) } @@ -504,7 +770,7 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s if vp, ok := ms.peers[nodeID]; ok { mut := vp.AsStruct() if seen { - mut.LastSeen = ptr.To(clock.Now()) + mut.LastSeen = new(clock.Now()) } else { mut.LastSeen = nil } @@ -516,7 +782,7 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s for nodeID, online := range resp.OnlineChange { if vp, ok := ms.peers[nodeID]; ok { mut := vp.AsStruct() - mut.Online = ptr.To(online) + mut.Online = new(online) ms.peers[nodeID] = mut.View() stats.changed++ } @@ -550,11 +816,11 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s patchDiscoKey.Add(1) } if v := pc.Online; v != nil { - mut.Online = ptr.To(*v) + mut.Online = new(*v) patchOnline.Add(1) } if v := pc.LastSeen; v != nil { - mut.LastSeen = ptr.To(*v) + mut.LastSeen = new(*v) patchLastSeen.Add(1) } if v := pc.KeyExpiry; v != nil { @@ -589,13 +855,22 @@ func (ms *mapSession) addUserProfile(nm *netmap.NetworkMap, userID tailcfg.UserI } var debugPatchifyPeer = envknob.RegisterBool("TS_DEBUG_PATCHIFY_PEER") +var debugPatchifyPeerMiss = envknob.RegisterBool("TS_DEBUG_PATCHIFY_PEER_MISS") + +// patchifyMissOnFalse, if non-nil, is called with the field name when +// patchifyPeer fails. It is set by an init func in map_debug.go. +var patchifyMissOnFalse func(string) // patchifyPeersChanged mutates resp to promote PeersChanged entries to PeersChangedPatch // when possible. func (ms *mapSession) patchifyPeersChanged(resp *tailcfg.MapResponse) { + var onFalse func(string) + if debugPatchifyPeerMiss() { + onFalse = patchifyMissOnFalse + } filtered := resp.PeersChanged[:0] for _, n := range resp.PeersChanged { - if p, ok := ms.patchifyPeer(n); ok { + if p, ok := ms.patchifyPeer(n, onFalse); ok { patchifiedPeer.Add(1) if debugPatchifyPeer() { patchj, _ := json.Marshal(p) @@ -618,12 +893,12 @@ func (ms *mapSession) patchifyPeersChanged(resp *tailcfg.MapResponse) { var nodeFields = sync.OnceValue(getNodeFields) -// getNodeFields returns the fails of tailcfg.Node. +// getNodeFields returns the fields of tailcfg.Node. func getNodeFields() []string { rt := reflect.TypeFor[tailcfg.Node]() - ret := make([]string, rt.NumField()) - for i := range rt.NumField() { - ret[i] = rt.Field(i).Name + ret := make([]string, 0, rt.NumField()) + for f := range rt.Fields() { + ret = append(ret, f.Name) } return ret } @@ -633,18 +908,27 @@ func getNodeFields() []string { // // It returns ok=false if a patch can't be made, (V, ok) on a delta, or (nil, // true) if all the fields were identical (a zero change). -func (ms *mapSession) patchifyPeer(n *tailcfg.Node) (_ *tailcfg.PeerChange, ok bool) { +func (ms *mapSession) patchifyPeer(n *tailcfg.Node, onFalse func(string)) (_ *tailcfg.PeerChange, ok bool) { + ms.peersMu.RLock() + defer ms.peersMu.RUnlock() + was, ok := ms.peers[n.ID] if !ok { + if onFalse != nil { + onFalse("peer_not_found") + } return nil, false } - return peerChangeDiff(was, n) + return peerChangeDiff(was, n, onFalse) } // peerChangeDiff returns the difference from 'was' to 'n', if possible. // // It returns (nil, true) if the fields were identical. -func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChange, ok bool) { +func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node, onFalse func(string)) (_ *tailcfg.PeerChange, ok bool) { + if onFalse == nil { + onFalse = func(string) {} + } var ret *tailcfg.PeerChange pc := func() *tailcfg.PeerChange { if ret == nil { @@ -668,31 +952,36 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang // And it was never sent by any known control server. case "ID": if was.ID() != n.ID { + onFalse(field) return nil, false } case "StableID": if was.StableID() != n.StableID { + onFalse(field) return nil, false } case "Name": if was.Name() != n.Name { + onFalse(field) return nil, false } case "User": if was.User() != n.User { + onFalse(field) return nil, false } case "Sharer": if was.Sharer() != n.Sharer { + onFalse(field) return nil, false } case "Key": if was.Key() != n.Key { - pc().Key = ptr.To(n.Key) + pc().Key = new(n.Key) } case "KeyExpiry": if !was.KeyExpiry().Equal(n.KeyExpiry) { - pc().KeyExpiry = ptr.To(n.KeyExpiry) + pc().KeyExpiry = new(n.KeyExpiry) } case "KeySignature": if !was.KeySignature().Equal(n.KeySignature) { @@ -700,18 +989,21 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang } case "Machine": if was.Machine() != n.Machine { + onFalse(field) return nil, false } case "DiscoKey": if was.DiscoKey() != n.DiscoKey { - pc().DiscoKey = ptr.To(n.DiscoKey) + pc().DiscoKey = new(n.DiscoKey) } case "Addresses": if !views.SliceEqual(was.Addresses(), views.SliceOf(n.Addresses)) { + onFalse(field) return nil, false } case "AllowedIPs": if !views.SliceEqual(was.AllowedIPs(), views.SliceOf(n.AllowedIPs)) { + onFalse(field) return nil, false } case "Endpoints": @@ -731,13 +1023,16 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang continue } if !was.Hostinfo().Valid() || !n.Hostinfo.Valid() { + onFalse(field) return nil, false } if !was.Hostinfo().Equal(n.Hostinfo) { + onFalse(field) return nil, false } case "Created": if !was.Created().Equal(n.Created) { + onFalse(field) return nil, false } case "Cap": @@ -765,38 +1060,45 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang } case "Tags": if !views.SliceEqual(was.Tags(), views.SliceOf(n.Tags)) { + onFalse(field) return nil, false } case "PrimaryRoutes": if !views.SliceEqual(was.PrimaryRoutes(), views.SliceOf(n.PrimaryRoutes)) { + onFalse(field) return nil, false } case "Online": if wasOnline, ok := was.Online().GetOk(); ok && n.Online != nil && *n.Online != wasOnline { - pc().Online = ptr.To(*n.Online) + pc().Online = new(*n.Online) } case "LastSeen": if wasSeen, ok := was.LastSeen().GetOk(); ok && n.LastSeen != nil && !wasSeen.Equal(*n.LastSeen) { - pc().LastSeen = ptr.To(*n.LastSeen) + pc().LastSeen = new(*n.LastSeen) } case "MachineAuthorized": if was.MachineAuthorized() != n.MachineAuthorized { + onFalse(field) return nil, false } case "UnsignedPeerAPIOnly": if was.UnsignedPeerAPIOnly() != n.UnsignedPeerAPIOnly { + onFalse(field) return nil, false } case "IsWireGuardOnly": if was.IsWireGuardOnly() != n.IsWireGuardOnly { + onFalse(field) return nil, false } case "IsJailed": if was.IsJailed() != n.IsJailed { + onFalse(field) return nil, false } case "Expired": if was.Expired() != n.Expired { + onFalse(field) return nil, false } case "SelfNodeV4MasqAddrForThisPeer": @@ -805,6 +1107,7 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang continue } if va, ok := va.GetOk(); !ok || vb == nil || va != *vb { + onFalse(field) return nil, false } case "SelfNodeV6MasqAddrForThisPeer": @@ -813,17 +1116,20 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang continue } if va, ok := va.GetOk(); !ok || vb == nil || va != *vb { + onFalse(field) return nil, false } case "ExitNodeDNSResolvers": va, vb := was.ExitNodeDNSResolvers(), views.SliceOfViews(n.ExitNodeDNSResolvers) if va.Len() != vb.Len() { + onFalse(field) return nil, false } for i := range va.Len() { if !va.At(i).Equal(vb.At(i)) { + onFalse(field) return nil, false } } @@ -836,7 +1142,28 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang return ret, true } +// PeerIDAndKeyByTailscaleIP returns the node ID and node Key from the peers +// map without touching the netmap itself. The implementation mirrors the +// implementation of [netmap.PeerByTailscaleIP]. +func (ms *mapSession) PeerIDAndKeyByTailscaleIP(ip netip.Addr) (tailcfg.NodeID, key.NodePublic, bool) { + ms.peersMu.RLock() + defer ms.peersMu.RUnlock() + for _, n := range ms.peers { + ad := n.Addresses() + for i := range ad.Len() { + a := ad.At(i) + if a.Addr() == ip { + return n.ID(), n.Key(), true + } + } + } + return 0, key.NodePublic{}, false +} + func (ms *mapSession) sortedPeers() []tailcfg.NodeView { + ms.peersMu.RLock() + defer ms.peersMu.RUnlock() + ret := slicesx.MapValues(ms.peers) slices.SortFunc(ret, func(a, b tailcfg.NodeView) int { return cmp.Compare(a.ID(), b.ID()) diff --git a/control/controlclient/map_debug.go b/control/controlclient/map_debug.go new file mode 100644 index 0000000000000..2d6012211cba7 --- /dev/null +++ b/control/controlclient/map_debug.go @@ -0,0 +1,16 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_debug + +package controlclient + +import "tailscale.com/metrics" + +var patchifyMissStats = metrics.NewLabelMap("counter_patchify_miss", "why") + +func init() { + patchifyMissOnFalse = func(field string) { + patchifyMissStats.Add(field, 1) + } +} diff --git a/control/controlclient/map_test.go b/control/controlclient/map_test.go index 11d4593f03fae..1c4dc6d781582 100644 --- a/control/controlclient/map_test.go +++ b/control/controlclient/map_test.go @@ -14,6 +14,7 @@ import ( "strings" "sync/atomic" "testing" + "testing/synctest" "time" "github.com/google/go-cmp/cmp" @@ -30,11 +31,12 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/persist" - "tailscale.com/types/ptr" "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/mak" "tailscale.com/util/must" + "tailscale.com/util/usermetric" "tailscale.com/util/zstdframe" + "tailscale.com/wgengine" ) func eps(s ...string) []netip.AddrPort { @@ -250,7 +252,7 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { mapRes: &tailcfg.MapResponse{ PeersChangedPatch: []*tailcfg.PeerChange{{ NodeID: 1, - Key: ptr.To(key.NodePublicFromRaw32(mem.B(append(make([]byte, 31), 'A')))), + Key: new(key.NodePublicFromRaw32(mem.B(append(make([]byte, 31), 'A')))), }}, }, want: peers(&tailcfg.Node{ ID: 1, @@ -281,7 +283,7 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { mapRes: &tailcfg.MapResponse{ PeersChangedPatch: []*tailcfg.PeerChange{{ NodeID: 1, - DiscoKey: ptr.To(key.DiscoPublicFromRaw32(mem.B(append(make([]byte, 31), 'A')))), + DiscoKey: new(key.DiscoPublicFromRaw32(mem.B(append(make([]byte, 31), 'A')))), }}, }, want: peers(&tailcfg.Node{ @@ -297,13 +299,13 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { mapRes: &tailcfg.MapResponse{ PeersChangedPatch: []*tailcfg.PeerChange{{ NodeID: 1, - Online: ptr.To(true), + Online: new(true), }}, }, want: peers(&tailcfg.Node{ ID: 1, Name: "foo", - Online: ptr.To(true), + Online: new(true), }), wantStats: updateStats{changed: 1}, }, @@ -313,13 +315,13 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { mapRes: &tailcfg.MapResponse{ PeersChangedPatch: []*tailcfg.PeerChange{{ NodeID: 1, - LastSeen: ptr.To(time.Unix(123, 0).UTC()), + LastSeen: new(time.Unix(123, 0).UTC()), }}, }, want: peers(&tailcfg.Node{ ID: 1, Name: "foo", - LastSeen: ptr.To(time.Unix(123, 0).UTC()), + LastSeen: new(time.Unix(123, 0).UTC()), }), wantStats: updateStats{changed: 1}, }, @@ -329,7 +331,7 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { mapRes: &tailcfg.MapResponse{ PeersChangedPatch: []*tailcfg.PeerChange{{ NodeID: 1, - KeyExpiry: ptr.To(time.Unix(123, 0).UTC()), + KeyExpiry: new(time.Unix(123, 0).UTC()), }}, }, want: peers(&tailcfg.Node{ @@ -624,6 +626,415 @@ func TestNetmapForResponse(t *testing.T) { }) } +func TestUpdateDiscoForNode(t *testing.T) { + tests := []struct { + name string + initialOnline bool + initialLastSeen time.Time + updateDiscoKey bool + updateOnline bool + updateLastSeen time.Time + wantUpdate bool + wantKeyChanged bool + }{ + { + name: "newer_key_not_online", + initialOnline: true, + initialLastSeen: time.Unix(1, 0), + updateDiscoKey: true, + updateOnline: false, + updateLastSeen: time.Now(), + wantUpdate: true, + wantKeyChanged: true, + }, + { + name: "newer_key_online", + initialOnline: true, + initialLastSeen: time.Unix(1, 0), + updateDiscoKey: true, + updateOnline: true, + updateLastSeen: time.Now(), + wantUpdate: true, + wantKeyChanged: true, + }, + { + name: "older_key_not_online", + initialOnline: false, + initialLastSeen: time.Now(), + updateDiscoKey: true, + updateOnline: false, + updateLastSeen: time.Unix(1, 0), + wantUpdate: false, + wantKeyChanged: false, + }, + { + name: "older_key_online", + initialOnline: false, + initialLastSeen: time.Now(), + updateDiscoKey: true, + updateOnline: true, + updateLastSeen: time.Unix(1, 0), + wantUpdate: true, + wantKeyChanged: true, + }, + { + name: "same_newer_key_not_online", + initialOnline: true, + initialLastSeen: time.Unix(1, 0), + updateDiscoKey: false, + updateOnline: false, + updateLastSeen: time.Now(), + wantUpdate: false, + wantKeyChanged: false, + }, + { + name: "same_newer_key_online", + initialOnline: true, + initialLastSeen: time.Unix(1, 0), + updateDiscoKey: false, + updateOnline: true, + updateLastSeen: time.Now(), + wantUpdate: false, + wantKeyChanged: false, + }, + { + name: "same_older_key_not_online", + initialOnline: false, + initialLastSeen: time.Now(), + updateDiscoKey: false, + updateOnline: false, + updateLastSeen: time.Unix(1, 0), + wantUpdate: false, + wantKeyChanged: false, + }, + { + name: "same_older_key_online", + initialOnline: false, + initialLastSeen: time.Now(), + updateDiscoKey: false, + updateOnline: true, + updateLastSeen: time.Unix(1, 0), + wantUpdate: true, + wantKeyChanged: false, + }, + { + name: "no_initial_last_seen", + initialOnline: false, + updateDiscoKey: true, + updateOnline: false, + updateLastSeen: time.Now(), + wantUpdate: true, + wantKeyChanged: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + synctest.Test(t, func(*testing.T) { + nu := &rememberLastNetmapUpdater{ + done: make(chan any, 1), + } + ms := newTestMapSession(t, nu) + defer ms.Close() + + oldKey := key.NewDisco() + + // Insert existing node + node := tailcfg.Node{ + ID: 1, + Key: key.NewNode().Public(), + DiscoKey: oldKey.Public(), + Online: &tt.initialOnline, + } + if !tt.initialLastSeen.IsZero() { + node.LastSeen = &tt.initialLastSeen + } + + if nm := ms.netmapForResponse(&tailcfg.MapResponse{ + Peers: []*tailcfg.Node{&node}, + }); len(nm.Peers) != 1 { + t.Fatalf("node not inserted") + } + + newKey := oldKey.Public() + if tt.updateDiscoKey { + newKey = key.NewDisco().Public() + } + ms.updateDiscoForNode(node.ID, node.Key, newKey, tt.updateLastSeen, tt.updateOnline) + + // We have an early escape that would not trigger the netmap updater. + synctest.Wait() + select { + case <-nu.done: + if !tt.wantUpdate { + t.Errorf("did not expect update, got: %v", nu.last) + } + default: + if tt.wantUpdate { + t.Errorf("expected update, did not get any") + } + } + + peer, ok := ms.peers[node.ID] + if !ok { + t.Fatal("node not found") + } + + keyChanged := peer.DiscoKey().Compare(oldKey.Public()) != 0 + if keyChanged != tt.wantKeyChanged { + t.Errorf("Disco key update: %t, wanted update: %t", keyChanged, tt.wantKeyChanged) + } + }) + }) + } +} + +func TestUpdateDiscoForNodeCallback(t *testing.T) { + t.Run("key_wired_through_to_updater", func(t *testing.T) { + nu := &rememberLastNetmapUpdater{ + done: make(chan any, 1), + } + ms := newTestMapSession(t, nu) + + oldKey := key.NewDisco() + + // Insert existing node + node := tailcfg.Node{ + ID: 1, + Key: key.NewNode().Public(), + DiscoKey: oldKey.Public(), + Online: new(false), + LastSeen: new(time.Unix(1, 0)), + } + + if nm := ms.netmapForResponse(&tailcfg.MapResponse{ + Peers: []*tailcfg.Node{&node}, + }); len(nm.Peers) != 1 { + t.Fatalf("node not inserted") + } + + newKey := key.NewDisco() + ms.updateDiscoForNode(node.ID, node.Key, newKey.Public(), time.Now(), false) + <-nu.done + + if nu.lastTSMPKey != node.Key || nu.lastTSMPDisco != newKey.Public() { + t.Fatalf("expected [%s]=%s, got [%s]=%s", node.Key, newKey.Public(), + nu.lastTSMPKey, nu.lastTSMPDisco) + } + }) + // Even though key stays in list of update, the updater only triggers on TSMP. + t.Run("key_not_wired_through_to_updater", func(t *testing.T) { + nu := &rememberLastNetmapUpdater{ + done: make(chan any, 1), + } + ms := newTestMapSession(t, nu) + + oldKey := key.NewDisco() + + // Insert existing node + node := tailcfg.Node{ + ID: 1, + Key: key.NewNode().Public(), + DiscoKey: oldKey.Public(), + Online: new(false), + LastSeen: new(time.Unix(1, 0)), + } + + if nm := ms.netmapForResponse(&tailcfg.MapResponse{ + Peers: []*tailcfg.Node{&node}, + }); len(nm.Peers) != 1 { + t.Fatalf("node not inserted") + } + + newKey := key.NewDisco().Public() + resp := &tailcfg.MapResponse{ + PeersChangedPatch: []*tailcfg.PeerChange{{ + NodeID: node.ID, + Key: &node.Key, + LastSeen: new(time.Now()), + Online: new(true), + DiscoKey: &newKey, + }}, + } + // Not TSMP Path, just regular injection path. + ms.HandleNonKeepAliveMapResponse(t.Context(), resp) + <-nu.done + + if !nu.lastTSMPKey.IsZero() || !nu.lastTSMPDisco.IsZero() { + t.Fatalf("expected zero keys, got [%s]=%s", + nu.lastTSMPKey, nu.lastTSMPDisco) + } + }) + + t.Run("test_deadlock", func(t *testing.T) { + nu := &rememberLastNetmapUpdater{ + done: make(chan any, 1), + } + ms := newTestMapSession(t, nu) + // Very barebones onDebug func that will let us exercise sleep command + // from control and potentially induce deadlocks. + ms.onDebug = func(ctx context.Context, d *tailcfg.Debug) error { + time.Sleep(time.Duration(d.SleepSeconds * float64(time.Second))) + return nil + } + + oldKey := key.NewDisco() + + // Insert existing node + node := tailcfg.Node{ + ID: 1, + Key: key.NewNode().Public(), + DiscoKey: oldKey.Public(), + Online: new(false), + LastSeen: new(time.Unix(1, 0)), + } + + if nm := ms.netmapForResponse(&tailcfg.MapResponse{ + Peers: []*tailcfg.Node{&node}, + }); len(nm.Peers) != 1 { + t.Fatalf("node not inserted") + } + + sleep1 := &tailcfg.MapResponse{ + Debug: &tailcfg.Debug{ + SleepSeconds: 1.0, + }, + } + ms.HandleNonKeepAliveMapResponse(t.Context(), sleep1) + + // Resembles the disco key advert subscriber running in a separate context. + go func() { + newKey := key.NewDisco() + ms.updateDiscoForNode(node.ID, node.Key, newKey.Public(), time.Now(), false) + }() + + ms.Close() + + <-nu.done + }) +} + +func TestUpdateDiscoForNodeCallbackWithFullNetmap(t *testing.T) { + now := time.Now() + oldTime := time.Unix(1, 0) + + tests := []struct { + name string + initialOnline bool + initialLastSeen time.Time + updateOnline bool + updateLastSeen time.Time + expectNewDisco bool + }{ + { + name: "disco-key-newer-lastSeen", + initialOnline: false, + initialLastSeen: oldTime, + updateOnline: false, + updateLastSeen: now, + expectNewDisco: true, + }, + { + name: "disco-key-older-lastSeen", + initialOnline: false, + initialLastSeen: now, + updateOnline: false, + updateLastSeen: oldTime, + expectNewDisco: false, + }, + { + name: "disco-key-newer-lastSeen-going-offline", + initialOnline: true, + initialLastSeen: oldTime, + updateOnline: false, + updateLastSeen: now, + expectNewDisco: true, + }, + { + name: "online-flip-newer-lastSeen", + initialOnline: false, + initialLastSeen: oldTime, + updateOnline: true, + updateLastSeen: now, + expectNewDisco: true, + }, + { + name: "local-lastseen-preserved-after-first-reconnect", + initialOnline: false, + initialLastSeen: now, + updateOnline: false, + updateLastSeen: now, + expectNewDisco: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nu := &rememberLastNetmapUpdater{ + done: make(chan any, 1), + } + ms := newTestMapSession(t, nu) + + oldKey := key.NewDisco() + + // Initial node + node := tailcfg.Node{ + ID: 1, + Key: key.NewNode().Public(), + DiscoKey: oldKey.Public(), + Online: new(tt.initialOnline), + LastSeen: new(tt.initialLastSeen), + Name: "host.network.ts.net", + } + + if nm := ms.netmapForResponse(&tailcfg.MapResponse{ + Peers: []*tailcfg.Node{&node}, + }); len(nm.Peers) != 1 { + t.Fatalf("node not inserted") + } + + newKey := key.NewDisco() + + // Updated node + newNode := tailcfg.Node{ + ID: 1, + Key: node.Key, + DiscoKey: newKey.Public(), + Online: new(tt.updateOnline), + LastSeen: new(tt.updateLastSeen), + Name: "host.network.ts.net", + } + + ms.HandleNonKeepAliveMapResponse(t.Context(), &tailcfg.MapResponse{ + Node: &newNode, + Peers: []*tailcfg.Node{ + &newNode, + }, + }) + <-nu.done + + newMap := nu.last + if n := len(newMap.Peers); n != 1 { + t.Fatalf("netmap not right length, got %d, expected %d", n, 1) + } + + peer := newMap.Peers[0] + + expectedDisco := oldKey.Public() + if tt.expectNewDisco { + expectedDisco = newKey.Public() + } + + if peer.Key() != node.Key || peer.DiscoKey() != expectedDisco { + t.Fatalf("expected [%s]=%s, got [%s]=%s", + node.Key, expectedDisco, + peer.Key(), peer.DiscoKey(), + ) + } + }) + } +} + func first[T any](s []T) T { if len(s) == 0 { var zero T @@ -770,21 +1181,21 @@ func TestPeerChangeDiff(t *testing.T) { }, { name: "patch-lastseen", - a: &tailcfg.Node{ID: 1, LastSeen: ptr.To(time.Unix(1, 0))}, - b: &tailcfg.Node{ID: 1, LastSeen: ptr.To(time.Unix(2, 0))}, - want: &tailcfg.PeerChange{NodeID: 1, LastSeen: ptr.To(time.Unix(2, 0))}, + a: &tailcfg.Node{ID: 1, LastSeen: new(time.Unix(1, 0))}, + b: &tailcfg.Node{ID: 1, LastSeen: new(time.Unix(2, 0))}, + want: &tailcfg.PeerChange{NodeID: 1, LastSeen: new(time.Unix(2, 0))}, }, { name: "patch-online-to-true", - a: &tailcfg.Node{ID: 1, Online: ptr.To(false)}, - b: &tailcfg.Node{ID: 1, Online: ptr.To(true)}, - want: &tailcfg.PeerChange{NodeID: 1, Online: ptr.To(true)}, + a: &tailcfg.Node{ID: 1, Online: new(false)}, + b: &tailcfg.Node{ID: 1, Online: new(true)}, + want: &tailcfg.PeerChange{NodeID: 1, Online: new(true)}, }, { name: "patch-online-to-false", - a: &tailcfg.Node{ID: 1, Online: ptr.To(true)}, - b: &tailcfg.Node{ID: 1, Online: ptr.To(false)}, - want: &tailcfg.PeerChange{NodeID: 1, Online: ptr.To(false)}, + a: &tailcfg.Node{ID: 1, Online: new(true)}, + b: &tailcfg.Node{ID: 1, Online: new(false)}, + want: &tailcfg.PeerChange{NodeID: 1, Online: new(false)}, }, { name: "mix-patchable-and-not", @@ -818,14 +1229,14 @@ func TestPeerChangeDiff(t *testing.T) { }, { name: "miss-change-masq-v4", - a: &tailcfg.Node{ID: 1, SelfNodeV4MasqAddrForThisPeer: ptr.To(netip.MustParseAddr("100.64.0.1"))}, - b: &tailcfg.Node{ID: 1, SelfNodeV4MasqAddrForThisPeer: ptr.To(netip.MustParseAddr("100.64.0.2"))}, + a: &tailcfg.Node{ID: 1, SelfNodeV4MasqAddrForThisPeer: new(netip.MustParseAddr("100.64.0.1"))}, + b: &tailcfg.Node{ID: 1, SelfNodeV4MasqAddrForThisPeer: new(netip.MustParseAddr("100.64.0.2"))}, want: nil, }, { name: "miss-change-masq-v6", - a: &tailcfg.Node{ID: 1, SelfNodeV6MasqAddrForThisPeer: ptr.To(netip.MustParseAddr("2001::3456"))}, - b: &tailcfg.Node{ID: 1, SelfNodeV6MasqAddrForThisPeer: ptr.To(netip.MustParseAddr("2001::3006"))}, + a: &tailcfg.Node{ID: 1, SelfNodeV6MasqAddrForThisPeer: new(netip.MustParseAddr("2001::3456"))}, + b: &tailcfg.Node{ID: 1, SelfNodeV6MasqAddrForThisPeer: new(netip.MustParseAddr("2001::3006"))}, want: nil, }, { @@ -839,17 +1250,20 @@ func TestPeerChangeDiff(t *testing.T) { a: &tailcfg.Node{ID: 1, CapMap: tailcfg.NodeCapMap{tailcfg.CapabilityAdmin: nil}}, b: &tailcfg.Node{ID: 1, CapMap: tailcfg.NodeCapMap{tailcfg.CapabilityAdmin: nil, tailcfg.CapabilityDebug: nil}}, want: &tailcfg.PeerChange{NodeID: 1, CapMap: tailcfg.NodeCapMap{tailcfg.CapabilityAdmin: nil, tailcfg.CapabilityDebug: nil}}, - }, { + }, + { name: "patch-capmap-remove-key", a: &tailcfg.Node{ID: 1, CapMap: tailcfg.NodeCapMap{tailcfg.CapabilityAdmin: nil}}, b: &tailcfg.Node{ID: 1, CapMap: tailcfg.NodeCapMap{}}, want: &tailcfg.PeerChange{NodeID: 1, CapMap: tailcfg.NodeCapMap{}}, - }, { + }, + { name: "patch-capmap-remove-as-nil", a: &tailcfg.Node{ID: 1, CapMap: tailcfg.NodeCapMap{tailcfg.CapabilityAdmin: nil}}, b: &tailcfg.Node{ID: 1}, want: &tailcfg.PeerChange{NodeID: 1, CapMap: tailcfg.NodeCapMap{}}, - }, { + }, + { name: "patch-capmap-add-key-to-empty-map", a: &tailcfg.Node{ID: 1}, b: &tailcfg.Node{ID: 1, CapMap: tailcfg.NodeCapMap{tailcfg.CapabilityAdmin: nil}}, @@ -864,7 +1278,7 @@ func TestPeerChangeDiff(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - pc, ok := peerChangeDiff(tt.a.View(), tt.b) + pc, ok := peerChangeDiff(tt.a.View(), tt.b, nil) if tt.wantEqual { if !ok || pc != nil { t.Errorf("got (%p, %v); want (nil, true); pc=%v", pc, ok, logger.AsJSON(pc)) @@ -885,7 +1299,7 @@ func TestPeerChangeDiffAllocs(t *testing.T) { a := &tailcfg.Node{ID: 1} b := &tailcfg.Node{ID: 1} n := testing.AllocsPerRun(10000, func() { - diff, ok := peerChangeDiff(a.View(), b) + diff, ok := peerChangeDiff(a.View(), b, nil) if !ok || diff != nil { t.Fatalf("unexpected result: (%s, %v)", logger.AsJSON(diff), ok) } @@ -1079,7 +1493,7 @@ func TestUpgradeNode(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var got *tailcfg.Node if tt.in != nil { - got = ptr.To(*tt.in) // shallow clone + got = new(*tt.in) // shallow clone } upgradeNode(got) if diff := cmp.Diff(tt.want, got); diff != "" { @@ -1090,7 +1504,6 @@ func TestUpgradeNode(t *testing.T) { } }) } - } func BenchmarkMapSessionDelta(b *testing.B) { @@ -1099,6 +1512,8 @@ func BenchmarkMapSessionDelta(b *testing.B) { ctx := context.Background() nu := &countingNetmapUpdater{} ms := newTestMapSession(b, nu) + // Disable log output for benchmarks to avoid races + ms.logf = func(string, ...any) {} res := &tailcfg.MapResponse{ Node: &tailcfg.Node{ ID: 1, @@ -1122,7 +1537,7 @@ func BenchmarkMapSessionDelta(b *testing.B) { {Proto: "peerapi-dns-proxy", Port: 1}, }, }).View(), - LastSeen: ptr.To(time.Unix(int64(i), 0)), + LastSeen: new(time.Unix(int64(i), 0)), }) } ms.HandleNonKeepAliveMapResponse(ctx, res) @@ -1483,3 +1898,178 @@ func TestLearnZstdOfKeepAlive(t *testing.T) { t.Fatalf("got %d zstd decodes; want %d", got, want) } } + +func TestPathDiscokeyerImplementations(t *testing.T) { + bus := eventbustest.NewBus(t) + ht := health.NewTracker(bus) + reg := new(usermetric.Registry) + e, err := wgengine.NewFakeUserspaceEngine(t.Logf, 0, ht, reg, bus) + if err != nil { + t.Fatal(err) + } + t.Cleanup(e.Close) + if _, ok := e.(patchDiscoKeyer); !ok { + t.Error("wgengine.userspaceEngine must implement patchDiscoKeyer") + } + + wd := wgengine.NewWatchdog(e) + if _, ok := wd.(patchDiscoKeyer); !ok { + t.Error("wgengine.watchdogEngine must implement patchDiscoKeyer") + } +} + +func TestPeerIDAndKeyByTailscaleIP(t *testing.T) { + peerKey1 := key.NewNode().Public() + peerKey2 := key.NewNode().Public() + + peer1 := &tailcfg.Node{ + ID: 1, + Key: peerKey1, + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, + } + peer2 := &tailcfg.Node{ + ID: 2, + Key: peerKey2, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.2/32"), + netip.MustParsePrefix("fd7a:115c::2/128"), + }, + } + + ms := newTestMapSession(t, nil) + ms.updateStateFromResponse(&tailcfg.MapResponse{ + Node: new(tailcfg.Node), + Peers: []*tailcfg.Node{peer1, peer2}, + }) + + t.Run("known_ip_peer1", func(t *testing.T) { + gotID, gotKey, ok := ms.PeerIDAndKeyByTailscaleIP(netip.MustParseAddr("100.64.0.1")) + if !ok { + t.Fatal("PeerIDAndKeyByTailscaleIP returned ok=false, want true") + } + if gotID != peer1.ID { + t.Errorf("NodeID = %v, want %v", gotID, peer1.ID) + } + if gotKey != peerKey1 { + t.Errorf("NodePublic = %v, want %v", gotKey, peerKey1) + } + }) + + t.Run("known_ip_peer2_v6", func(t *testing.T) { + gotID, gotKey, ok := ms.PeerIDAndKeyByTailscaleIP(netip.MustParseAddr("fd7a:115c::2")) + if !ok { + t.Fatal("PeerIDAndKeyByTailscaleIP returned ok=false, want true") + } + if gotID != peer2.ID { + t.Errorf("NodeID = %v, want %v", gotID, peer2.ID) + } + if gotKey != peerKey2 { + t.Errorf("NodePublic = %v, want %v", gotKey, peerKey2) + } + }) + + t.Run("unknown_ip", func(t *testing.T) { + gotID, gotKey, ok := ms.PeerIDAndKeyByTailscaleIP(netip.MustParseAddr("100.64.0.99")) + if ok { + t.Errorf("PeerIDAndKeyByTailscaleIP returned ok=true for unknown IP, got id=%v key=%v", gotID, gotKey) + } + }) +} + +func TestRemoveUnwantedDiscoUpdates(t *testing.T) { + tests := []struct { + name string + viaTSMP bool + existingOnline bool + sameKey bool + newerLastSeen bool + wantAccepted bool + }{ + { + name: "tsmp_online_peer_same_key", + viaTSMP: true, + existingOnline: true, + sameKey: true, + newerLastSeen: true, + wantAccepted: false, + }, + { + name: "not_tsmp_online_peer_same_key", + viaTSMP: false, + existingOnline: true, + sameKey: true, + newerLastSeen: true, + wantAccepted: true, + }, + { + name: "tsmp_offline_peer_same_key", + viaTSMP: true, + existingOnline: false, + sameKey: true, + newerLastSeen: true, + wantAccepted: true, + }, + { + name: "tsmp_online_peer_diff_key", + viaTSMP: true, + existingOnline: true, + sameKey: false, + newerLastSeen: true, + wantAccepted: true, + }, + { + name: "tsmp_online_peer_same_key_old_lastseen", + viaTSMP: true, + existingOnline: true, + sameKey: true, + newerLastSeen: false, + wantAccepted: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ms := newTestMapSession(t, &rememberLastNetmapUpdater{done: make(chan any, 1)}) + + existingKey := key.NewDisco().Public() + existingOnline := tt.existingOnline + initialLastSeen := time.Unix(1, 0) + + ms.updateStateFromResponse(&tailcfg.MapResponse{ + Peers: []*tailcfg.Node{{ + ID: 1, + Key: key.NewNode().Public(), + DiscoKey: existingKey, + Online: &existingOnline, + LastSeen: &initialLastSeen, + }}, + }) + + changeKey := existingKey + if !tt.sameKey { + changeKey = key.NewDisco().Public() + } + changeOnline := false // must be false to reach the new guard + updateLastSeen := time.Unix(2, 0) + if !tt.newerLastSeen { + updateLastSeen = time.Unix(0, 0) + } + + resp := &tailcfg.MapResponse{ + PeersChangedPatch: []*tailcfg.PeerChange{{ + NodeID: 1, + DiscoKey: &changeKey, + Online: &changeOnline, + LastSeen: &updateLastSeen, + }}, + } + + ms.removeUnwantedDiscoUpdates(resp, tt.viaTSMP) + + got := len(resp.PeersChangedPatch) > 0 + if got != tt.wantAccepted { + t.Errorf("accepted=%v, want %v", got, tt.wantAccepted) + } + }) + } +} diff --git a/control/controlclient/sign_supported.go b/control/controlclient/sign_supported.go index ea6fa28e34479..f3340d5a6c98f 100644 --- a/control/controlclient/sign_supported.go +++ b/control/controlclient/sign_supported.go @@ -1,9 +1,7 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -//go:build windows - -// darwin,cgo is also supported by certstore but untested, so it is not enabled. +//go:build windows || (darwin && !ios && cgo) package controlclient diff --git a/control/controlclient/sign_unsupported.go b/control/controlclient/sign_unsupported.go index ff830282e4496..a371cbaf1e609 100644 --- a/control/controlclient/sign_unsupported.go +++ b/control/controlclient/sign_unsupported.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -//go:build !windows +//go:build (!windows && !(darwin && cgo)) || ios package controlclient diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index e812091745ea5..2aabcbb6418a0 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -479,6 +479,9 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad // Disable HTTP2, since h2 can't do protocol switching. tr.TLSClientConfig.NextProtos = []string{} tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{} + if a.ExtraRootCAs != nil { + tr.TLSClientConfig.RootCAs = a.ExtraRootCAs + } tr.TLSClientConfig = tlsdial.Config(a.HealthTracker, tr.TLSClientConfig) if !tr.TLSClientConfig.InsecureSkipVerify { panic("unexpected") // should be set by tlsdial.Config diff --git a/control/controlhttp/constants.go b/control/controlhttp/constants.go index 26ace871c1268..efa8d84990bc1 100644 --- a/control/controlhttp/constants.go +++ b/control/controlhttp/constants.go @@ -4,6 +4,7 @@ package controlhttp import ( + "crypto/x509" "net/http" "net/url" "sync/atomic" @@ -85,6 +86,9 @@ type Dialer struct { // HealthTracker, if non-nil, is the health tracker to use. HealthTracker *health.Tracker + // ExtraRootCAs, if non-nil, specifies additional trusted root CAs for TLS. + ExtraRootCAs *x509.CertPool + // DialPlan, if set, contains instructions from the control server on // how to connect to it. If present, we will try the methods in this // plan before falling back to DNS. diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index c02ac758ebf16..7f0203cd051df 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -814,8 +814,8 @@ func runDialPlanTest(t *testing.T, plan *tailcfg.ControlDialPlan, want []netip.A // split on "|" first to remove memnet pipe suffix addrPart := raddrStr - if idx := strings.Index(raddrStr, "|"); idx >= 0 { - addrPart = raddrStr[:idx] + if before, _, ok := strings.Cut(raddrStr, "|"); ok { + addrPart = before } host, _, err2 := net.SplitHostPort(addrPart) diff --git a/control/controlknobs/controlknobs.go b/control/controlknobs/controlknobs.go index 1861a122e2f9e..77a496349c314 100644 --- a/control/controlknobs/controlknobs.go +++ b/control/controlknobs/controlknobs.go @@ -21,11 +21,6 @@ type Knobs struct { // DisableUPnP indicates whether to attempt UPnP mapping. DisableUPnP atomic.Bool - // KeepFullWGConfig is whether we should disable the lazy wireguard - // programming and instead give WireGuard the full netmap always, even for - // idle peers. - KeepFullWGConfig atomic.Bool - // RandomizeClientPort is whether control says we should randomize // the client port. RandomizeClientPort atomic.Bool @@ -62,12 +57,6 @@ type Knobs struct { // netfiltering, unless overridden by the user. LinuxForceNfTables atomic.Bool - // SeamlessKeyRenewal is whether to renew node keys without breaking connections. - // This is enabled by default in 1.90 and later, but we but we can remotely disable - // it from the control plane if there's a problem. - // http://go/seamless-key-renewal - SeamlessKeyRenewal atomic.Bool - // ProbeUDPLifetime is whether the node should probe UDP path lifetime on // the tail end of an active direct connection in magicsock. ProbeUDPLifetime atomic.Bool @@ -131,7 +120,6 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) { } has := capMap.Contains var ( - keepFullWG = has(tailcfg.NodeAttrDebugDisableWGTrim) disableUPnP = has(tailcfg.NodeAttrDisableUPnP) randomizeClientPort = has(tailcfg.NodeAttrRandomizeClientPort) disableDeltaUpdates = has(tailcfg.NodeAttrDisableDeltaUpdates) @@ -142,8 +130,6 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) { silentDisco = has(tailcfg.NodeAttrSilentDisco) forceIPTables = has(tailcfg.NodeAttrLinuxMustUseIPTables) forceNfTables = has(tailcfg.NodeAttrLinuxMustUseNfTables) - seamlessKeyRenewal = has(tailcfg.NodeAttrSeamlessKeyRenewal) - disableSeamlessKeyRenewal = has(tailcfg.NodeAttrDisableSeamlessKeyRenewal) probeUDPLifetime = has(tailcfg.NodeAttrProbeUDPLifetime) appCStoreRoutes = has(tailcfg.NodeAttrStoreAppCRoutes) userDialUseRoutes = has(tailcfg.NodeAttrUserDialUseRoutes) @@ -161,7 +147,6 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) { oneCGNAT.Set(false) } - k.KeepFullWGConfig.Store(keepFullWG) k.DisableUPnP.Store(disableUPnP) k.RandomizeClientPort.Store(randomizeClientPort) k.OneCGNAT.Store(oneCGNAT) @@ -181,21 +166,6 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) { k.DisableSkipStatusQueue.Store(disableSkipStatusQueue) k.DisableHostsFileUpdates.Store(disableHostsFileUpdates) k.ForceRegisterMagicDNSIPv4Only.Store(forceRegisterMagicDNSIPv4Only) - - // If both attributes are present, then "enable" should win. This reflects - // the history of seamless key renewal. - // - // Before 1.90, seamless was a private alpha, opt-in feature. Devices would - // only seamless do if customers opted in using the seamless renewal attr. - // - // In 1.90 and later, seamless is the default behaviour, and devices will use - // seamless unless explicitly told not to by control (e.g. if we discover - // a bug and want clients to use the prior behaviour). - // - // If a customer has opted in to the pre-1.90 seamless implementation, we - // don't want to switch it off for them -- we only want to switch it off for - // devices that haven't opted in. - k.SeamlessKeyRenewal.Store(seamlessKeyRenewal || !disableSeamlessKeyRenewal) } // AsDebugJSON returns k as something that can be marshalled with json.Marshal @@ -205,17 +175,15 @@ func (k *Knobs) AsDebugJSON() map[string]any { return nil } ret := map[string]any{} - rt := reflect.TypeFor[Knobs]() rv := reflect.ValueOf(k).Elem() // of *k - for i := 0; i < rt.NumField(); i++ { - name := rt.Field(i).Name - switch v := rv.Field(i).Addr().Interface().(type) { + for sf, fv := range rv.Fields() { + switch v := fv.Addr().Interface().(type) { case *atomic.Bool: - ret[name] = v.Load() + ret[sf.Name] = v.Load() case *syncs.AtomicValue[opt.Bool]: - ret[name] = v.Load() + ret[sf.Name] = v.Load() default: - panic(fmt.Sprintf("unknown field type %T for %v", v, name)) + panic(fmt.Sprintf("unknown field type %T for %v", v, sf.Name)) } } return ret diff --git a/control/ts2021/client.go b/control/ts2021/client.go index 0f0e7598b5591..5770bae090274 100644 --- a/control/ts2021/client.go +++ b/control/ts2021/client.go @@ -7,6 +7,7 @@ import ( "bytes" "cmp" "context" + "crypto/x509" "encoding/json" "errors" "fmt" @@ -86,6 +87,9 @@ type ClientOpts struct { // HealthTracker, if non-nil, is the health tracker to use. HealthTracker *health.Tracker + // ExtraRootCAs, if non-nil, specifies additional trusted root CAs for TLS. + ExtraRootCAs *x509.CertPool + // DialPlan, if set, is a function that should return an explicit plan // on how to connect to the server. DialPlan func() *tailcfg.ControlDialPlan @@ -252,6 +256,7 @@ func (nc *Client) dial(ctx context.Context) (*Conn, error) { Logf: nc.logf, NetMon: nc.opts.NetMon, HealthTracker: nc.opts.HealthTracker, + ExtraRootCAs: nc.opts.ExtraRootCAs, Clock: tstime.StdClock{}, } clientConn, err := chd.Dial(ctx) diff --git a/control/tsp/map.go b/control/tsp/map.go new file mode 100644 index 0000000000000..961c5dd57c0f6 --- /dev/null +++ b/control/tsp/map.go @@ -0,0 +1,415 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "bytes" + "cmp" + "context" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + + "github.com/klauspost/compress/zstd" + "tailscale.com/control/ts2021" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// errSessionClosed is returned by [MapSession.Next] and +// [MapSession.NextInto] when called after [MapSession.Close]. +var errSessionClosed = errors.New("tsp: map session closed") + +// DefaultMaxMessageSize is the default cap, in bytes, on the size of a +// single compressed map response frame. See [MapOpts.MaxMessageSize]. +const DefaultMaxMessageSize = 4 << 20 + +// zstdDecoderPool is a pool of *zstd.Decoder reused across MapSessions to +// amortize the cost of setting up zstd state. Decoders are returned via +// [MapSession.Close]; entries are reclaimed by the runtime under memory +// pressure via sync.Pool semantics. +var zstdDecoderPool sync.Pool // of *zstd.Decoder + +// MapOpts contains options for sending a map request. +type MapOpts struct { + // NodeKey is the node's private key. Required. + NodeKey key.NodePrivate + + // Hostinfo is the host information to send. Optional; + // if nil, a minimal default is used. + Hostinfo *tailcfg.Hostinfo + + // Stream is whether to receive multiple MapResponses over + // the same HTTP connection. + Stream bool + + // OmitPeers is whether the client is okay with the Peers list + // being omitted in the response. + OmitPeers bool + + // MaxMessageSize is the maximum size in bytes of any single + // compressed map response frame on the wire. If zero, + // [DefaultMaxMessageSize] is used. + MaxMessageSize int64 +} + +// framedReader is an io.Reader that consumes a stream of length-prefixed +// frames (each a little-endian uint32 length followed by that many bytes) +// from r and yields only the frame payloads back-to-back. +// +// This lets us feed the concatenated zstd frames from our wire protocol +// into a single streaming zstd decoder. Zstd's file format permits +// concatenation (RFC 8478 §2), and klauspost's decoder handles it +// transparently. +// +// If onNewFrame is non-nil, it is called after each new 4-byte length +// header is successfully read. Used to reset the per-message decoded-size +// budget downstream. +type framedReader struct { + r io.Reader + maxSize int64 // per-frame compressed-size cap + remain int // bytes remaining in the current frame + onNewFrame func() +} + +func (f *framedReader) Read(p []byte) (int, error) { + if f.remain == 0 { + var hdr [4]byte + if _, err := io.ReadFull(f.r, hdr[:]); err != nil { + return 0, err + } + sz := int64(binary.LittleEndian.Uint32(hdr[:])) + if sz == 0 { + return 0, fmt.Errorf("map response: zero-length frame") + } + if sz > f.maxSize { + return 0, fmt.Errorf("map response frame size %d exceeds max %d", sz, f.maxSize) + } + f.remain = int(sz) + if f.onNewFrame != nil { + f.onNewFrame() + } + } + if len(p) > f.remain { + p = p[:f.remain] + } + n, err := f.r.Read(p) + f.remain -= n + return n, err +} + +// boundedReader is an io.Reader that yields at most remain bytes from r +// before returning an error. Call reset to raise the budget back to max, +// typically at a new message boundary. +// +// Used to cap the decoded size of a single map response so a malicious +// server can't send a small zstd frame that explodes into gigabytes of +// junk for the json.Decoder to consume. +type boundedReader struct { + r io.Reader + max int64 + remain int64 +} + +func (b *boundedReader) Read(p []byte) (int, error) { + if b.remain <= 0 { + return 0, fmt.Errorf("map response decoded size exceeds max %d", b.max) + } + if int64(len(p)) > b.remain { + p = p[:b.remain] + } + n, err := b.r.Read(p) + b.remain -= int64(n) + return n, err +} + +func (b *boundedReader) reset() { b.remain = b.max } + +// MapSession wraps an in-progress map response stream. Call Next to read +// each MapResponse. Call Close when done. +type MapSession struct { + res *http.Response + stream bool + noiseDoer func(*http.Request) (*http.Response, error) + + // inNext detects concurrent NextInto callers. It CAS-flips + // false→true on entry and back to false on exit; a failed CAS + // panics, akin to how the Go runtime detects concurrent map + // access. It does not serialize Close vs. NextInto; that's + // nextMu's job. + inNext atomic.Bool + + // nextMu is held while [MapSession.NextInto] is running jdec.Decode, + // so that Close can wait for an in-flight Decode to unwind before it + // touches zdec (Reset, pool-Put) and avoids racing with the running + // Read chain that Decode drives. + nextMu sync.Mutex + read int // guarded by nextMu + closed bool // guarded by nextMu + zdec *zstd.Decoder // reads from a framedReader wrapping res.Body + jdec *json.Decoder // reads decompressed JSON from zdec + + closeOnce sync.Once + closeErr error +} + +// NoiseRoundTrip sends an HTTP request over the Noise channel used by this map session. +func (s *MapSession) NoiseRoundTrip(req *http.Request) (*http.Response, error) { + return s.noiseDoer(req) +} + +// Next reads and returns the next MapResponse from the stream. +// For non-streaming sessions, the first call returns the single response +// and subsequent calls return io.EOF. +// For streaming sessions, Next blocks until the next response arrives +// or the server closes the connection. +// +// Each call allocates a fresh MapResponse. Callers that want to amortize +// the allocation across calls can use [MapSession.NextInto]. +// +// Next and NextInto are not safe to call concurrently from multiple +// goroutines on the same [MapSession]; a concurrent call panics, akin +// to the Go runtime's concurrent map access detection. [MapSession.Close] +// may be called concurrently to abort an in-flight Next. +func (s *MapSession) Next() (*tailcfg.MapResponse, error) { + var resp tailcfg.MapResponse + if err := s.NextInto(&resp); err != nil { + return nil, err + } + return &resp, nil +} + +// NextInto is like [MapSession.Next] but decodes the next MapResponse into +// the caller-supplied *resp rather than allocating a new one. The pointer's +// pointee is zeroed before decoding so fields from a prior response do not +// persist. +// +// For non-streaming sessions, the first call decodes the single response +// and subsequent calls return io.EOF. +// For streaming sessions, NextInto blocks until the next response arrives +// or the server closes the connection. +// +// See [MapSession.Next] for concurrency rules; those apply to NextInto too. +func (s *MapSession) NextInto(resp *tailcfg.MapResponse) error { + if !s.inNext.CompareAndSwap(false, true) { + panic("tsp: invalid concurrent call to MapSession.Next/NextInto") + } + defer s.inNext.Store(false) + + s.nextMu.Lock() + defer s.nextMu.Unlock() + if s.closed { + return errSessionClosed + } + if !s.stream && s.read > 0 { + return io.EOF + } + *resp = tailcfg.MapResponse{} + if err := s.jdec.Decode(resp); err != nil { + return err + } + s.read++ + return nil +} + +// Close returns the session's zstd decoder to the pool and closes the +// underlying HTTP response body. It is safe to call Close multiple times +// and from multiple goroutines, including while a [MapSession.Next] or +// [MapSession.NextInto] call is in flight on another goroutine (which +// will return an error once the body close propagates). +func (s *MapSession) Close() error { + // Callers are likely to race a deferred Close with a time.AfterFunc + // timeout (or similar) Close that aborts a hung Next. Without the + // Once, both Closes would Put the same *zstd.Decoder into the pool, + // corrupting it, and the Reset/Put in one would race with the + // zdec.Read that the hung Next is driving. + // + // Ordering inside the Once: close the body first to unblock any + // in-flight NextInto (its Read chain ends at res.Body and will + // return an error once it's closed). That lets NextInto unwind and + // release nextMu. Only then do we take nextMu ourselves and touch + // zdec, which is safe because no goroutine is still reading from + // it. Acquiring nextMu before closing the body would deadlock + // against a hung NextInto. + s.closeOnce.Do(func() { + s.closeErr = s.res.Body.Close() + s.nextMu.Lock() + defer s.nextMu.Unlock() + s.closed = true + s.zdec.Reset(nil) + zstdDecoderPool.Put(s.zdec) + }) + return s.closeErr +} + +// SendMapUpdateOpts contains options for [Client.SendMapUpdate]. +type SendMapUpdateOpts struct { + // NodeKey is the node's private key. Required. + NodeKey key.NodePrivate + + // DiscoKey, if non-zero, is the node's disco public key. + // Peers use it to verify disco pings from this node, which is + // what enables direct (non-DERP) paths. + DiscoKey key.DiscoPublic + + // Hostinfo is the host information to send. Optional; + // if nil, a minimal default is used. + Hostinfo *tailcfg.Hostinfo +} + +// SendMapUpdate sends a one-shot, non-streaming MapRequest to push small +// updates (such as the node's endpoints, hostinfo, or disco public key) to the +// coordination server without starting or disturbing a streaming map session. +func (c *Client) SendMapUpdate(ctx context.Context, opts SendMapUpdateOpts) error { + if opts.NodeKey.IsZero() { + return fmt.Errorf("NodeKey is required") + } + + hi := opts.Hostinfo + if hi == nil { + hi = defaultHostinfo() + } + + mapReq := tailcfg.MapRequest{ + Version: tailcfg.CurrentCapabilityVersion, + NodeKey: opts.NodeKey.Public(), + DiscoKey: opts.DiscoKey, + Hostinfo: hi, + Compress: "zstd", + + // A lite update that lets the server persist our state without breaking + // any existing streaming map session. See the [tailcfg.MapResponse] + // OmitPeers docs. + OmitPeers: true, + Stream: false, + ReadOnly: false, + } + + body, err := json.Marshal(mapReq) + if err != nil { + return fmt.Errorf("encoding map request: %w", err) + } + + nc, err := c.noiseClient(ctx) + if err != nil { + return fmt.Errorf("establishing noise connection: %w", err) + } + + url := c.serverURL + "/machine/map" + url = strings.Replace(url, "http:", "https:", 1) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("creating map request: %w", err) + } + ts2021.AddLBHeader(req, opts.NodeKey.Public()) + + res, err := nc.Do(req) + if err != nil { + return fmt.Errorf("map request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != 200 { + msg, _ := io.ReadAll(res.Body) + return fmt.Errorf("map request: http %d: %.200s", + res.StatusCode, strings.TrimSpace(string(msg))) + } + io.Copy(io.Discard, res.Body) + return nil +} + +// Map sends a map request to the coordination server and returns a MapSession +// for reading the framed, zstd-compressed response(s). +func (c *Client) Map(ctx context.Context, opts MapOpts) (*MapSession, error) { + if opts.NodeKey.IsZero() { + return nil, fmt.Errorf("NodeKey is required") + } + + hi := opts.Hostinfo + if hi == nil { + hi = defaultHostinfo() + } + + mapReq := tailcfg.MapRequest{ + Version: tailcfg.CurrentCapabilityVersion, + NodeKey: opts.NodeKey.Public(), + Hostinfo: hi, + Stream: opts.Stream, + Compress: "zstd", + OmitPeers: opts.OmitPeers, + // Streaming requires the server to track us as "connected", + // which in turn requires ReadOnly=false. Non-streaming polls + // stay ReadOnly to minimize side effects. + ReadOnly: !opts.Stream, + } + + body, err := json.Marshal(mapReq) + if err != nil { + return nil, fmt.Errorf("encoding map request: %w", err) + } + + nc, err := c.noiseClient(ctx) + if err != nil { + return nil, fmt.Errorf("establishing noise connection: %w", err) + } + + url := c.serverURL + "/machine/map" + url = strings.Replace(url, "http:", "https:", 1) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating map request: %w", err) + } + ts2021.AddLBHeader(req, opts.NodeKey.Public()) + + res, err := nc.Do(req) + if err != nil { + return nil, fmt.Errorf("map request: %w", err) + } + + if res.StatusCode != 200 { + msg, _ := io.ReadAll(res.Body) + res.Body.Close() + return nil, fmt.Errorf("map request: http %d: %.200s", + res.StatusCode, strings.TrimSpace(string(msg))) + } + + maxMessageSize := cmp.Or(opts.MaxMessageSize, DefaultMaxMessageSize) + bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize} + fr := &framedReader{ + r: res.Body, + maxSize: maxMessageSize, + onNewFrame: bounded.reset, + } + + zdec, _ := zstdDecoderPool.Get().(*zstd.Decoder) + if zdec != nil { + if err := zdec.Reset(fr); err != nil { + // Reset can fail if the previous stream is in a bad state; drop + // the decoder and create a fresh one. + zdec = nil + } + } + if zdec == nil { + zdec, err = zstd.NewReader(fr, zstd.WithDecoderConcurrency(1)) + if err != nil { + res.Body.Close() + return nil, fmt.Errorf("creating zstd decoder: %w", err) + } + } + bounded.r = zdec + + return &MapSession{ + res: res, + stream: opts.Stream, + noiseDoer: nc.Do, + zdec: zdec, + jdec: json.NewDecoder(bounded), + }, nil +} diff --git a/control/tsp/map_test.go b/control/tsp/map_test.go new file mode 100644 index 0000000000000..ddfde3971800c --- /dev/null +++ b/control/tsp/map_test.go @@ -0,0 +1,409 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/klauspost/compress/zstd" + "tailscale.com/health" + "tailscale.com/tailcfg" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/key" +) + +func TestMapAgainstTestControl(t *testing.T) { + ctrl := &testcontrol.Server{} + ctrl.HTTPTestServer = httptest.NewUnstartedServer(ctrl) + ctrl.HTTPTestServer.Start() + t.Cleanup(ctrl.HTTPTestServer.Close) + baseURL := ctrl.HTTPTestServer.URL + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + ht := new(health.Tracker) + + serverKey, err := DiscoverServerKey(ctx, baseURL) + if err != nil { + t.Fatalf("DiscoverServerKey: %v", err) + } + + register := func(hostname string) (nodeKey key.NodePrivate, machineKey key.MachinePrivate) { + t.Helper() + nodeKey = key.NewNode() + machineKey = key.NewMachine() + c, err := NewClient(ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKey, + HealthTracker: ht, + }) + if err != nil { + t.Fatalf("NewClient %s: %v", hostname, err) + } + defer c.Close() + c.SetControlPublicKey(serverKey) + if _, err := c.Register(ctx, RegisterOpts{ + NodeKey: nodeKey, + Hostinfo: &tailcfg.Hostinfo{Hostname: hostname}, + }); err != nil { + t.Fatalf("Register %s: %v", hostname, err) + } + return nodeKey, machineKey + } + + nodeKeyA, machineKeyA := register("a") + nodeKeyB, _ := register("b") + + clientA, err := NewClient(ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKeyA, + HealthTracker: ht, + }) + if err != nil { + t.Fatalf("NewClient A: %v", err) + } + defer clientA.Close() + clientA.SetControlPublicKey(serverKey) + + session, err := clientA.Map(ctx, MapOpts{ + NodeKey: nodeKeyA, + Hostinfo: &tailcfg.Hostinfo{Hostname: "a"}, + Stream: true, + }) + if err != nil { + t.Fatalf("Map: %v", err) + } + defer session.Close() + + // nextNonKeepalive returns the next non-keepalive MapResponse, to keep + // the test robust if a server-side keepalive arrives mid-test. + nextNonKeepalive := func() *tailcfg.MapResponse { + t.Helper() + for { + resp, err := session.Next() + if err != nil { + t.Fatalf("session.Next: %v", err) + } + if resp.KeepAlive { + continue + } + return resp + } + } + + // First MapResponse: expect node A as self and node B in Peers. + first := nextNonKeepalive() + if first.Node == nil { + t.Fatal("first response has nil Node") + } + if got, want := first.Node.Key, nodeKeyA.Public(); got != want { + t.Errorf("first Node.Key = %v, want %v", got, want) + } + var foundB bool + for _, p := range first.Peers { + if p.Key == nodeKeyB.Public() { + foundB = true + break + } + } + if !foundB { + t.Errorf("peer B (%v) not in first response's Peers (%d peers)", nodeKeyB.Public(), len(first.Peers)) + } + + // Inject raw MapResponses and verify they come out the reader, in order. + // msgToSend is single-slot, so we must consume each before injecting the next. + for i := range 3 { + want := fmt.Sprintf("injected-%d.example.com", i) + inject := &tailcfg.MapResponse{Domain: want} + if !ctrl.AddRawMapResponse(nodeKeyA.Public(), inject) { + t.Fatalf("AddRawMapResponse %d: node not connected", i) + } + got := nextNonKeepalive() + if got.Domain != want { + t.Errorf("injected %d: got Domain=%q, want %q", i, got.Domain, want) + } + } +} + +// TestSendMapUpdateAgainstTestControl verifies that a [Client.SendMapUpdate] +// call from one node lands on the coordination server and that peer nodes +// subsequently observe the updated DiscoKey via their own streaming map poll. +func TestSendMapUpdateAgainstTestControl(t *testing.T) { + ctrl := &testcontrol.Server{} + ctrl.HTTPTestServer = httptest.NewUnstartedServer(ctrl) + ctrl.HTTPTestServer.Start() + t.Cleanup(ctrl.HTTPTestServer.Close) + baseURL := ctrl.HTTPTestServer.URL + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + ht := new(health.Tracker) + + serverKey, err := DiscoverServerKey(ctx, baseURL) + if err != nil { + t.Fatalf("DiscoverServerKey: %v", err) + } + + register := func(hostname string) (nodeKey key.NodePrivate, machineKey key.MachinePrivate) { + t.Helper() + nodeKey = key.NewNode() + machineKey = key.NewMachine() + c, err := NewClient(ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKey, + HealthTracker: ht, + }) + if err != nil { + t.Fatalf("NewClient %s: %v", hostname, err) + } + defer c.Close() + c.SetControlPublicKey(serverKey) + if _, err := c.Register(ctx, RegisterOpts{ + NodeKey: nodeKey, + Hostinfo: &tailcfg.Hostinfo{Hostname: hostname}, + }); err != nil { + t.Fatalf("Register %s: %v", hostname, err) + } + return nodeKey, machineKey + } + + nodeKeyA, machineKeyA := register("a") + nodeKeyB, machineKeyB := register("b") + + // B starts a streaming map poll so we can observe updates about peer A. + clientB, err := NewClient(ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKeyB, + HealthTracker: ht, + }) + if err != nil { + t.Fatalf("NewClient B: %v", err) + } + defer clientB.Close() + clientB.SetControlPublicKey(serverKey) + + session, err := clientB.Map(ctx, MapOpts{ + NodeKey: nodeKeyB, + Hostinfo: &tailcfg.Hostinfo{Hostname: "b"}, + Stream: true, + }) + if err != nil { + t.Fatalf("Map B: %v", err) + } + defer session.Close() + + nextNonKeepalive := func() *tailcfg.MapResponse { + t.Helper() + for { + resp, err := session.Next() + if err != nil { + t.Fatalf("session.Next: %v", err) + } + if resp.KeepAlive { + continue + } + return resp + } + } + + // Drain B's initial MapResponse. A should be present as a peer with a + // zero DiscoKey (it never pushed one). + first := nextNonKeepalive() + var initialA *tailcfg.Node + for _, p := range first.Peers { + if p.Key == nodeKeyA.Public() { + initialA = p + break + } + } + if initialA == nil { + t.Fatalf("peer A (%v) not in B's first MapResponse", nodeKeyA.Public()) + } + if !initialA.DiscoKey.IsZero() { + t.Fatalf("peer A initial DiscoKey = %v, want zero", initialA.DiscoKey) + } + + // A pushes its disco key via SendMapUpdate. + clientA, err := NewClient(ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKeyA, + HealthTracker: ht, + }) + if err != nil { + t.Fatalf("NewClient A: %v", err) + } + defer clientA.Close() + clientA.SetControlPublicKey(serverKey) + + wantDisco := key.NewDisco().Public() + if err := clientA.SendMapUpdate(ctx, SendMapUpdateOpts{ + NodeKey: nodeKeyA, + DiscoKey: wantDisco, + Hostinfo: &tailcfg.Hostinfo{Hostname: "a"}, + }); err != nil { + t.Fatalf("SendMapUpdate: %v", err) + } + + // B should now observe A's new DiscoKey in a subsequent MapResponse. + for { + resp := nextNonKeepalive() + for _, p := range resp.Peers { + if p.Key != nodeKeyA.Public() { + continue + } + if p.DiscoKey == wantDisco { + return // success + } + } + } +} + +// newTestPipeline builds the same framedReader → zstd → boundedReader → +// json.Decoder pipeline that [Client.Map] builds for a live session, but +// feeds it from a raw byte slice. Returned jdec can be used with Decode to +// pull out MapResponses. +func newTestPipeline(t testing.TB, wire []byte, maxMessageSize int64) *json.Decoder { + t.Helper() + bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize} + fr := &framedReader{ + r: bytes.NewReader(wire), + maxSize: maxMessageSize, + onNewFrame: bounded.reset, + } + zdec, err := zstd.NewReader(fr, zstd.WithDecoderConcurrency(1)) + if err != nil { + t.Fatalf("zstd.NewReader: %v", err) + } + t.Cleanup(zdec.Close) + bounded.r = zdec + return json.NewDecoder(bounded) +} + +// zstdFrame returns a zstd-compressed frame of b. +func zstdFrame(t testing.TB, b []byte) []byte { + t.Helper() + enc, err := zstd.NewWriter(io.Discard, zstd.WithEncoderConcurrency(1)) + if err != nil { + t.Fatalf("zstd.NewWriter: %v", err) + } + defer enc.Close() + return enc.EncodeAll(b, nil) +} + +// wireFrame writes a 4-byte little-endian length prefix plus payload to buf. +func wireFrame(buf *bytes.Buffer, payload []byte) { + var hdr [4]byte + binary.LittleEndian.PutUint32(hdr[:], uint32(len(payload))) + buf.Write(hdr[:]) + buf.Write(payload) +} + +// TestMapFrameSizeTooLarge verifies that a 4-byte length prefix claiming +// a frame larger than the configured cap is rejected before any payload +// bytes are read from the stream. +func TestMapFrameSizeTooLarge(t *testing.T) { + const max = 4 << 20 + var wire bytes.Buffer + var hdr [4]byte + binary.LittleEndian.PutUint32(hdr[:], (max + 1)) + wire.Write(hdr[:]) + + jdec := newTestPipeline(t, wire.Bytes(), max) + var resp tailcfg.MapResponse + err := jdec.Decode(&resp) + if err == nil { + t.Fatal("Decode: got nil error, want frame-too-large") + } + if !strings.Contains(err.Error(), "exceeds max") { + t.Errorf("Decode error = %q, want one containing %q", err, "exceeds max") + } +} + +// TestMapDecodedSizeTooLarge verifies that a small on-wire frame (well +// under the cap) which decompresses into a huge JSON payload is rejected. +// This is the "zstd bomb" case: a tiny compressed frame that would +// explode into a huge decoded payload for json.Decoder to consume. +func TestMapDecodedSizeTooLarge(t *testing.T) { + const max = 4 << 20 + big := strings.Repeat("a", 5<<20) // 5 MiB of 'a' + raw, err := json.Marshal(&tailcfg.MapResponse{Domain: big}) + if err != nil { + t.Fatal(err) + } + if int64(len(raw)) <= max { + t.Fatalf("raw JSON unexpectedly small: %d", len(raw)) + } + compressed := zstdFrame(t, raw) + if int64(len(compressed)) >= max { + t.Fatalf("compressed too large (%d); test needs a more compressible payload", len(compressed)) + } + + var wire bytes.Buffer + wireFrame(&wire, compressed) + + jdec := newTestPipeline(t, wire.Bytes(), max) + var resp tailcfg.MapResponse + err = jdec.Decode(&resp) + if err == nil { + t.Fatal("Decode: got nil error, want decoded-size-exceeded") + } + if !strings.Contains(err.Error(), "exceeds max") { + t.Errorf("Decode error = %q, want one containing %q", err, "exceeds max") + } +} + +// TestMapBudgetResetsBetweenFrames verifies that the per-message decoded +// budget is reset at each new frame boundary. Two consecutive 3-MiB frames +// should both decode successfully under a 4-MiB per-frame cap. Without the +// reset, the second frame would fail (remaining budget after frame 1 = +// 4MiB - 3MiB = 1MiB, and we'd try to read 3MiB more). +func TestMapBudgetResetsBetweenFrames(t *testing.T) { + const max = 4 << 20 + payload := strings.Repeat("a", 3<<20) + r1 := &tailcfg.MapResponse{Domain: payload + "-one"} + r2 := &tailcfg.MapResponse{Domain: payload + "-two"} + + var wire bytes.Buffer + for _, r := range []*tailcfg.MapResponse{r1, r2} { + raw, err := json.Marshal(r) + if err != nil { + t.Fatal(err) + } + if int64(len(raw)) >= max { + t.Fatalf("raw JSON size %d >= max %d; would fail budget check by itself", len(raw), max) + } + compressed := zstdFrame(t, raw) + if int64(len(compressed)) >= max { + t.Fatalf("compressed size %d >= max %d", len(compressed), max) + } + wireFrame(&wire, compressed) + } + + jdec := newTestPipeline(t, wire.Bytes(), max) + + var got1, got2 tailcfg.MapResponse + if err := jdec.Decode(&got1); err != nil { + t.Fatalf("first Decode: %v", err) + } + if got1.Domain != r1.Domain { + t.Errorf("first Domain mismatch (len %d vs %d)", len(got1.Domain), len(r1.Domain)) + } + if err := jdec.Decode(&got2); err != nil { + t.Fatalf("second Decode: %v", err) + } + if got2.Domain != r2.Domain { + t.Errorf("second Domain mismatch (len %d vs %d)", len(got2.Domain), len(r2.Domain)) + } +} diff --git a/control/tsp/nodefile.go b/control/tsp/nodefile.go new file mode 100644 index 0000000000000..8cae11ba9fe67 --- /dev/null +++ b/control/tsp/nodefile.go @@ -0,0 +1,105 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "encoding/json" + "fmt" + "os" + + "tailscale.com/types/key" +) + +// ServerInfo identifies a coordination server by its URL and Noise public key. +type ServerInfo struct { + // URL is the base URL of the coordination server, without any path + // (e.g. "https://controlplane.tailscale.com"). + // + // There is no default value; a URL must always be supplied. + URL string `json:"server_url"` + + // Key is the server's Noise public key, used to establish an encrypted + // channel between the client and the coordination server. + Key key.MachinePublic `json:"server_key"` +} + +// NodeFile is the JSON structure for a node credentials file. It contains +// the private keys that authenticate a node to a coordination server. +// +// Example: +// +// { +// "node_key": "privkey:...", +// "machine_key": "privkey:...", +// "server_url": "https://controlplane.tailscale.com", +// "server_key": "mkey:..." +// } +// +// Note that node and machine private keys share the same "privkey:" +// textual form; they are disambiguated by the surrounding JSON field +// names rather than by any prefix in the key itself. +type NodeFile struct { + // NodeKey is the node's WireGuard private key. The corresponding + // public key identifies this node to other peers. + NodeKey key.NodePrivate `json:"node_key"` + + // MachineKey is the machine's private key. It authenticates this + // machine to the coordination server over Noise. + MachineKey key.MachinePrivate `json:"machine_key"` + + ServerInfo // server_url and server_key +} + +// ReadNodeFile reads and parses a node JSON file. +func ReadNodeFile(path string) (NodeFile, error) { + data, err := os.ReadFile(path) + if err != nil { + return NodeFile{}, err + } + var nf NodeFile + if err := json.Unmarshal(data, &nf); err != nil { + return NodeFile{}, fmt.Errorf("parsing node file %q: %w", path, err) + } + return nf, nil +} + +// WriteNodeFile writes a node JSON file. The file is created with mode 0600. +func WriteNodeFile(path string, nf NodeFile) error { + if err := nf.Check(); err != nil { + return fmt.Errorf("invalid NodeFile: %w", err) + } + return os.WriteFile(path, nf.AsJSON(), 0600) +} + +// AsJSON returns nf as a pretty-printed JSON object, terminated by a newline. +// +// It always succeeds and always returns a valid JSON object. It does not +// validate that the fields of nf are non-zero; it is the caller's +// responsibility to call [NodeFile.Check] first if they want to reject +// incomplete NodeFiles. +func (nf NodeFile) AsJSON() []byte { + out, err := json.MarshalIndent(nf, "", " ") + if err != nil { + panic(fmt.Sprintf("NodeFile.AsJSON: %v", err)) // unreachable: all fields marshal successfully + } + return append(out, '\n') +} + +// Check reports whether nf has all required fields set. +// It returns an error describing the first zero-valued field, if any. +func (nf NodeFile) Check() error { + if nf.NodeKey.IsZero() { + return fmt.Errorf("node_key is missing") + } + if nf.MachineKey.IsZero() { + return fmt.Errorf("machine_key is missing") + } + if nf.URL == "" { + return fmt.Errorf("server_url is missing") + } + if nf.ServerInfo.Key.IsZero() { + return fmt.Errorf("server_key is missing") + } + return nil +} diff --git a/control/tsp/nodefile_test.go b/control/tsp/nodefile_test.go new file mode 100644 index 0000000000000..4a019f25fcbd4 --- /dev/null +++ b/control/tsp/nodefile_test.go @@ -0,0 +1,116 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "tailscale.com/types/key" +) + +func TestNodeFileRoundTrip(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "node.json") + + nf := NodeFile{ + NodeKey: key.NewNode(), + MachineKey: key.NewMachine(), + ServerInfo: ServerInfo{ + URL: "https://controlplane.tailscale.com", + Key: key.NewMachine().Public(), + }, + } + + if err := WriteNodeFile(path, nf); err != nil { + t.Fatalf("WriteNodeFile: %v", err) + } + + got, err := ReadNodeFile(path) + if err != nil { + t.Fatalf("ReadNodeFile: %v", err) + } + if !got.NodeKey.Equal(nf.NodeKey) { + t.Errorf("node key mismatch") + } + if !got.MachineKey.Equal(nf.MachineKey) { + t.Errorf("machine key mismatch") + } + if got.URL != nf.URL { + t.Errorf("server URL = %q, want %q", got.URL, nf.URL) + } + if got.ServerInfo.Key != nf.ServerInfo.Key { + t.Errorf("server key mismatch") + } +} + +// TestNodeFileFormat verifies that ReadNodeFile can parse a fixed JSON literal, +// ensuring we don't accidentally change the on-disk format. +func TestNodeFileFormat(t *testing.T) { + const fileContents = `{ + "node_key": "privkey:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + "machine_key": "privkey:fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210", + "server_url": "https://controlplane.tailscale.com", + "server_key": "mkey:1111111111111111111111111111111111111111111111111111111111111111" +}` + dir := t.TempDir() + path := filepath.Join(dir, "node.json") + if err := os.WriteFile(path, []byte(fileContents), 0600); err != nil { + t.Fatal(err) + } + + nf, err := ReadNodeFile(path) + if err != nil { + t.Fatalf("ReadNodeFile: %v", err) + } + if nf.NodeKey.IsZero() { + t.Error("node key is zero") + } + if nf.MachineKey.IsZero() { + t.Error("machine key is zero") + } + if nf.URL != "https://controlplane.tailscale.com" { + t.Errorf("server URL = %q", nf.URL) + } + if nf.ServerInfo.Key.IsZero() { + t.Error("server key is zero") + } +} + +// TestNodeFileWriteFormat verifies that WriteNodeFile produces the expected +// JSON field names. +func TestNodeFileWriteFormat(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "node.json") + + nf := NodeFile{ + NodeKey: key.NewNode(), + MachineKey: key.NewMachine(), + ServerInfo: ServerInfo{ + URL: "https://example.com", + Key: key.NewMachine().Public(), + }, + } + + if err := WriteNodeFile(path, nf); err != nil { + t.Fatalf("WriteNodeFile: %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("parsing written JSON: %v", err) + } + for _, field := range []string{"node_key", "machine_key", "server_url", "server_key"} { + if _, ok := raw[field]; !ok { + t.Errorf("missing JSON field %q in written file", field) + } + } +} diff --git a/control/tsp/register.go b/control/tsp/register.go new file mode 100644 index 0000000000000..0d2baf75fe3ef --- /dev/null +++ b/control/tsp/register.go @@ -0,0 +1,116 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "tailscale.com/control/ts2021" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// RegisterOpts contains options for registering a node. +type RegisterOpts struct { + // NodeKey is the node's private key. Required. + NodeKey key.NodePrivate + + // Hostinfo is the host information to send. Optional; + // if nil, a minimal default is used. + Hostinfo *tailcfg.Hostinfo + + // Ephemeral marks the node as ephemeral. + Ephemeral bool + + // AuthKey is a pre-authorized auth key. + AuthKey string + + // Tags is a list of ACL tags to request. + Tags []string + + // MaxResponseSize is the maximum size in bytes of the register + // response body. If zero, [DefaultMaxMessageSize] is used. + MaxResponseSize int64 +} + +// Register sends a registration request to the coordination server +// and returns the response. +func (c *Client) Register(ctx context.Context, opts RegisterOpts) (*tailcfg.RegisterResponse, error) { + hi := opts.Hostinfo + if hi == nil { + hi = defaultHostinfo() + } + if len(opts.Tags) > 0 { + hi.RequestTags = opts.Tags + } + + regReq := tailcfg.RegisterRequest{ + Version: tailcfg.CurrentCapabilityVersion, + NodeKey: opts.NodeKey.Public(), + Hostinfo: hi, + Ephemeral: opts.Ephemeral, + } + if opts.AuthKey != "" { + regReq.Auth = &tailcfg.RegisterResponseAuth{ + AuthKey: opts.AuthKey, + } + } + + body, err := json.Marshal(regReq) + if err != nil { + return nil, fmt.Errorf("encoding register request: %w", err) + } + + nc, err := c.noiseClient(ctx) + if err != nil { + return nil, fmt.Errorf("establishing noise connection: %w", err) + } + + url := c.serverURL + "/machine/register" + url = strings.Replace(url, "http:", "https:", 1) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating register request: %w", err) + } + ts2021.AddLBHeader(req, opts.NodeKey.Public()) + + res, err := nc.Do(req) + if err != nil { + return nil, fmt.Errorf("register request: %w", err) + } + defer res.Body.Close() + + maxResponseSize := cmp.Or(opts.MaxResponseSize, DefaultMaxMessageSize) + + if res.StatusCode != 200 { + msg, _ := io.ReadAll(io.LimitReader(res.Body, maxResponseSize)) + return nil, fmt.Errorf("register request: http %d: %.200s", + res.StatusCode, strings.TrimSpace(string(msg))) + } + + // Read up to maxResponseSize+1 so we can distinguish "exactly at cap" from + // "over the cap" rather than relying on a truncated json parse error. + data, err := io.ReadAll(io.LimitReader(res.Body, maxResponseSize+1)) + if err != nil { + return nil, fmt.Errorf("reading register response: %w", err) + } + if int64(len(data)) > maxResponseSize { + return nil, fmt.Errorf("register response exceeds max %d", maxResponseSize) + } + var resp tailcfg.RegisterResponse + if err := json.Unmarshal(data, &resp); err != nil { + return nil, fmt.Errorf("decoding register response: %w", err) + } + if resp.Error != "" { + return nil, fmt.Errorf("register: %s", resp.Error) + } + return &resp, nil +} diff --git a/control/tsp/tsp.go b/control/tsp/tsp.go new file mode 100644 index 0000000000000..23f2fc26115b9 --- /dev/null +++ b/control/tsp/tsp.go @@ -0,0 +1,257 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package tsp provides a client for speaking the Tailscale protocol +// to a coordination server over Noise. +package tsp + +import ( + "bufio" + "bytes" + "cmp" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "strconv" + "sync" + + "tailscale.com/control/ts2021" + "tailscale.com/health" + "tailscale.com/ipn" + "tailscale.com/net/tsdial" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/version" +) + +// DefaultServerURL is the default coordination server base URL, +// used when ClientOpts.ServerURL is empty. +const DefaultServerURL = ipn.DefaultControlURL + +// ClientOpts contains options for creating a new Client. +type ClientOpts struct { + // ServerURL is the base URL of the coordination server + // (e.g. "https://controlplane.tailscale.com"). + // If empty, DefaultServerURL is used. + ServerURL string + + // MachineKey is this node's machine private key. Required. + MachineKey key.MachinePrivate + + // Logf is the log function. If nil, logger.Discard is used. + Logf logger.Logf + + // HealthTracker, if non-nil, is the health tracker passed through + // to the underlying noise client. May be nil. + HealthTracker *health.Tracker +} + +// Client is a Tailscale protocol client that speaks to a coordination +// server over Noise. +type Client struct { + opts ClientOpts + serverURL string + logf logger.Logf + + mu sync.Mutex + nc *ts2021.Client // nil until noiseClient called + serverPub key.MachinePublic // zero until set or discovered +} + +// NewClient creates a new Client configured to talk to the coordination server +// specified in opts. It performs no I/O; the server's public key is discovered +// lazily on first use or can be set explicitly via SetControlPublicKey. +func NewClient(opts ClientOpts) (*Client, error) { + if opts.MachineKey.IsZero() { + return nil, fmt.Errorf("MachineKey is required") + } + logf := opts.Logf + if logf == nil { + logf = logger.Discard + } + return &Client{ + opts: opts, + serverURL: cmp.Or(opts.ServerURL, DefaultServerURL), + logf: logf, + }, nil +} + +// SetControlPublicKey sets the server's public key, bypassing lazy discovery. +// Any existing noise client is invalidated and will be re-created on next use. +func (c *Client) SetControlPublicKey(k key.MachinePublic) { + c.mu.Lock() + defer c.mu.Unlock() + c.serverPub = k + c.nc = nil +} + +// DiscoverServerKey fetches the server's public key from the coordination +// server and stores it for subsequent use. Any existing noise client is +// invalidated. +func (c *Client) DiscoverServerKey(ctx context.Context) (key.MachinePublic, error) { + k, err := DiscoverServerKey(ctx, c.serverURL) + if err != nil { + return key.MachinePublic{}, err + } + c.mu.Lock() + defer c.mu.Unlock() + c.serverPub = k + c.nc = nil + return k, nil +} + +// DiscoverServerKey fetches the coordination server's public key from the +// given server URL. It is a standalone function that requires no client state. +func DiscoverServerKey(ctx context.Context, serverURL string) (key.MachinePublic, error) { + serverURL = cmp.Or(serverURL, DefaultServerURL) + keysURL := serverURL + "/key?v=" + strconv.Itoa(int(tailcfg.CurrentCapabilityVersion)) + req, err := http.NewRequestWithContext(ctx, "GET", keysURL, nil) + if err != nil { + return key.MachinePublic{}, fmt.Errorf("creating key request: %w", err) + } + res, err := http.DefaultClient.Do(req) + if err != nil { + return key.MachinePublic{}, fmt.Errorf("fetching server key: %w", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return key.MachinePublic{}, fmt.Errorf("fetching server key: %s", res.Status) + } + var keys struct { + PublicKey key.MachinePublic + } + if err := json.NewDecoder(res.Body).Decode(&keys); err != nil { + return key.MachinePublic{}, fmt.Errorf("decoding server key: %w", err) + } + return keys.PublicKey, nil +} + +// noiseClient returns the ts2021 noise client, creating it lazily if needed. +// If the server's public key is not yet known, it is discovered via HTTP. +func (c *Client) noiseClient(ctx context.Context) (*ts2021.Client, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.nc != nil { + return c.nc, nil + } + + if c.serverPub.IsZero() { + // Discover server key without holding the lock, to avoid blocking + // other callers during the HTTP request. + c.mu.Unlock() + k, err := DiscoverServerKey(ctx, c.serverURL) + c.mu.Lock() + if err != nil { + return nil, err + } + // Re-check: another goroutine may have set it while we were unlocked. + if c.serverPub.IsZero() { + c.serverPub = k + } + // If nc was created by another goroutine while unlocked, use it. + if c.nc != nil { + return c.nc, nil + } + } + + nc, err := ts2021.NewClient(ts2021.ClientOpts{ + ServerURL: c.serverURL, + PrivKey: c.opts.MachineKey, + ServerPubKey: c.serverPub, + Dialer: tsdial.NewFromFuncForDebug(c.logf, (&net.Dialer{}).DialContext), + Logf: c.logf, + HealthTracker: c.opts.HealthTracker, + }) + if err != nil { + return nil, fmt.Errorf("creating noise client: %w", err) + } + c.nc = nc + return nc, nil +} + +// AnswerC2NPing handles a c2n PingRequest from the control plane by parsing the +// embedded HTTP request in the payload, routing it locally, and POSTing the HTTP +// response back to pr.URL using doNoiseRequest. The POST is done in a new +// goroutine so this method does not block. +// +// It reports whether the ping was handled. Unhandled pings (nil pr, non-c2n +// types, or unrecognized c2n paths) return false. +func (c *Client) AnswerC2NPing(ctx context.Context, pr *tailcfg.PingRequest, doNoiseRequest func(*http.Request) (*http.Response, error)) (handled bool) { + if pr == nil || pr.Types != "c2n" { + return false + } + + // Parse the HTTP request from the payload. + httpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(pr.Payload))) + if err != nil { + c.logf("parsing c2n ping payload: %v", err) + return false + } + + // Route the request locally. + var httpResp *http.Response + switch httpReq.URL.Path { + case "/echo": + body, _ := io.ReadAll(httpReq.Body) + httpResp = &http.Response{ + StatusCode: 200, + Status: "200 OK", + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(body)), + ContentLength: int64(len(body)), + } + default: + c.logf("ignoring c2n ping request for unhandled path %q", httpReq.URL.Path) + return false + } + + // Serialize the HTTP response. + var buf bytes.Buffer + if err := httpResp.Write(&buf); err != nil { + c.logf("serializing c2n ping response: %v", err) + return false + } + + // Send the response back to the control plane over the Noise channel. + go func() { + req, err := http.NewRequestWithContext(ctx, "POST", pr.URL, &buf) + if err != nil { + c.logf("creating c2n ping reply request: %v", err) + return + } + resp, err := doNoiseRequest(req) + if err != nil { + c.logf("sending c2n ping reply: %v", err) + return + } + resp.Body.Close() + }() + return true +} + +// Close closes the client and releases resources. +func (c *Client) Close() error { + c.mu.Lock() + nc := c.nc + c.nc = nil + c.mu.Unlock() + if nc != nil { + nc.Close() + } + return nil +} + +func defaultHostinfo() *tailcfg.Hostinfo { + return &tailcfg.Hostinfo{ + OS: version.OS(), + IPNVersion: version.Long(), + } +} diff --git a/derp/client_test.go b/derp/client_test.go index e1bcaba8bf2c8..697aa49617925 100644 --- a/derp/client_test.go +++ b/derp/client_test.go @@ -6,7 +6,6 @@ package derp import ( "bufio" "bytes" - "io" "net" "reflect" "sync" @@ -126,36 +125,6 @@ func TestClientSendPong(t *testing.T) { } } -func BenchmarkWriteUint32(b *testing.B) { - w := bufio.NewWriter(io.Discard) - b.ReportAllocs() - b.ResetTimer() - for range b.N { - writeUint32(w, 0x0ba3a) - } -} - -type nopRead struct{} - -func (r nopRead) Read(p []byte) (int, error) { - return len(p), nil -} - -var sinkU32 uint32 - -func BenchmarkReadUint32(b *testing.B) { - r := bufio.NewReader(nopRead{}) - var err error - b.ReportAllocs() - b.ResetTimer() - for range b.N { - sinkU32, err = readUint32(r) - if err != nil { - b.Fatal(err) - } - } -} - type countWriter struct { mu sync.Mutex writes int diff --git a/derp/derp.go b/derp/derp.go index a7d0ea80191a8..a478c16a06911 100644 --- a/derp/derp.go +++ b/derp/derp.go @@ -27,7 +27,7 @@ import ( // including its on-wire framing overhead) const MaxPacketSize = 64 << 10 -// Magic is the DERP Magic number, sent in the frameServerKey frame +// Magic is the DERP Magic number, sent in the FrameServerKey frame // upon initial connection. const Magic = "DERP🔑" // 8 bytes: 0x44 45 52 50 f0 9f 94 91 @@ -58,15 +58,15 @@ Protocol flow: Login: * client connects -* server sends frameServerKey -* client sends frameClientInfo -* server sends frameServerInfo +* server sends FrameServerKey +* client sends FrameClientInfo +* server sends FrameServerInfo Steady state: -* server occasionally sends frameKeepAlive (or framePing) -* client responds to any framePing with a framePong -* client sends frameSendPacket -* server then sends frameRecvPacket to recipient +* server occasionally sends FrameKeepAlive (or FramePing) +* client responds to any FramePing with a FramePong +* client sends FrameSendPacket +* server then sends FrameRecvPacket to recipient */ const ( FrameServerKey = FrameType(0x01) // 8B magic + 32B public key + (0+ bytes future use) @@ -78,16 +78,16 @@ const ( FrameKeepAlive = FrameType(0x06) // no payload, no-op (to be replaced with ping/pong) FrameNotePreferred = FrameType(0x07) // 1 byte payload: 0x01 or 0x00 for whether this is client's home node - // framePeerGone is sent from server to client to signal that + // FramePeerGone is sent from server to client to signal that // a previous sender is no longer connected. That is, if A // sent to B, and then if A disconnects, the server sends - // framePeerGone to B so B can forget that a reverse path + // FramePeerGone to B so B can forget that a reverse path // exists on that connection to get back to A. It is also sent // if A tries to send a CallMeMaybe to B and the server has no // record of B FramePeerGone = FrameType(0x08) // 32B pub key of peer that's gone + 1 byte reason - // framePeerPresent is like framePeerGone, but for other members of the DERP + // FramePeerPresent is like FramePeerGone, but for other members of the DERP // region when they're meshed up together. // // The message is at least 32 bytes (the public key of the peer that's @@ -98,15 +98,15 @@ const ( // servers might send more. FramePeerPresent = FrameType(0x09) - // frameWatchConns is how one DERP node in a regional mesh + // FrameWatchConns is how one DERP node in a regional mesh // subscribes to the others in the region. // There's no payload. If the sender doesn't have permission, the connection // is closed. Otherwise, the client is initially flooded with - // framePeerPresent for all connected nodes, and then a stream of - // framePeerPresent & framePeerGone has peers connect and disconnect. + // FramePeerPresent for all connected nodes, and then a stream of + // FramePeerPresent & FramePeerGone has peers connect and disconnect. FrameWatchConns = FrameType(0x10) - // frameClosePeer is a privileged frame type (requires the + // FrameClosePeer is a privileged frame type (requires the // mesh key for now) that closes the provided peer's // connection. (To be used for cluster load balancing // purposes, when clients end up on a non-ideal node) @@ -115,14 +115,14 @@ const ( FramePing = FrameType(0x12) // 8 byte ping payload, to be echoed back in framePong FramePong = FrameType(0x13) // 8 byte payload, the contents of the ping being replied to - // frameHealth is sent from server to client to tell the client + // FrameHealth is sent from server to client to tell the client // if their connection is unhealthy somehow. Currently the only unhealthy state // is whether the connection is detected as a duplicate. // The entire frame body is the text of the error message. An empty message // clears the error state. FrameHealth = FrameType(0x14) - // frameRestarting is sent from server to client for the + // FrameRestarting is sent from server to client for the // server to declare that it's restarting. Payload is two big // endian uint32 durations in milliseconds: when to reconnect, // and how long to try total. See ServerRestartingMessage docs for @@ -140,7 +140,7 @@ const ( PeerGoneReasonMeshConnBroke = PeerGoneReasonType(0xf0) // invented by Client.RunWatchConnectionLoop on disconnect; not sent on the wire ) -// PeerPresentFlags is an optional byte of bit flags sent after a framePeerPresent message. +// PeerPresentFlags is an optional byte of bit flags sent after a FramePeerPresent message. // // For a modern server, the value should always be non-zero. If the value is zero, // that means the server doesn't support this field. @@ -168,6 +168,8 @@ const FastStartHeader = "Derp-Fast-Start" var bin = binary.BigEndian +// writeUint32 writes v to bw one byte at a time +// as a big-endian uint32. func writeUint32(bw *bufio.Writer, v uint32) error { var b [4]byte bin.PutUint32(b[:], v) @@ -183,21 +185,6 @@ func writeUint32(bw *bufio.Writer, v uint32) error { return nil } -func readUint32(br *bufio.Reader) (uint32, error) { - var b [4]byte - // Reading a byte at a time is a bit silly, - // but it causes b not to escape, - // which more than pays for the silliness. - for i := range &b { - c, err := br.ReadByte() - if err != nil { - return 0, err - } - b[i] = c - } - return bin.Uint32(b[:]), nil -} - // ReadFrameTypeHeader reads a frame header from br and // verifies that the frame type matches wantType. // @@ -213,18 +200,16 @@ func ReadFrameTypeHeader(br *bufio.Reader, wantType FrameType) (frameLen uint32, return frameLen, err } -// ReadFrameHeader reads the header of a DERP frame, -// reading 5 bytes from br. +// ReadFrameHeader reads a DERP frame header ([FrameHeaderLen] bytes) from br. +// It uses Peek+Discard to read directly from bufio's internal buffer +// without copying or allocating. func ReadFrameHeader(br *bufio.Reader) (t FrameType, frameLen uint32, err error) { - tb, err := br.ReadByte() - if err != nil { - return 0, 0, err - } - frameLen, err = readUint32(br) + hdr, err := br.Peek(FrameHeaderLen) if err != nil { return 0, 0, err } - return FrameType(tb), frameLen, nil + defer br.Discard(FrameHeaderLen) + return FrameType(hdr[0]), bin.Uint32(hdr[1:FrameHeaderLen]), nil } // readFrame reads a frame header and then reads its payload into @@ -260,14 +245,26 @@ func readFrame(br *bufio.Reader, maxSize uint32, b []byte) (t FrameType, frameLe return t, frameLen, err } -// WriteFrameHeader writes a frame header to bw. -// -// The frame header is 5 bytes: a one byte frame type -// followed by a big-endian uint32 length of the -// remaining frame (not including the 5 byte header). +// WriteFrameHeader writes a DERP frame header to bw: a one-byte frame +// type followed by a big-endian uint32 frame length. // +// It uses AvailableBuffer to append the header directly into bufio's +// internal buffer without allocation, falling back to WriteByte when +// the buffer has insufficient space. // It does not flush bw. func WriteFrameHeader(bw *bufio.Writer, t FrameType, frameLen uint32) error { + // Fast path: enough space in the buffer to append the header + // directly without allocation via AvailableBuffer. + if bw.Available() >= FrameHeaderLen { + buf := bw.AvailableBuffer() + buf = append(buf, byte(t)) + buf = bin.AppendUint32(buf, frameLen) + _, err := bw.Write(buf) + return err + } + // Slow path: buffer nearly full. Write byte-at-a-time to let + // bufio flush as needed, avoiding a heap allocation from append + // growing past AvailableBuffer's capacity. if err := bw.WriteByte(byte(t)); err != nil { return err } diff --git a/derp/derp_client.go b/derp/derp_client.go index 1e9d48e1456c8..f786179d62833 100644 --- a/derp/derp_client.go +++ b/derp/derp_client.go @@ -554,7 +554,7 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro return sm, nil case FrameKeepAlive: // A one-way keep-alive message that doesn't require an acknowledgement. - // This predated framePing/framePong. + // This predated FramePing/FramePong. return KeepAliveMessage{}, nil case FramePeerGone: if n < KeyLen { diff --git a/derp/derp_test.go b/derp/derp_test.go index cff069dd4470c..0edbaff170406 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -34,6 +34,165 @@ type ( Client = derp.Client ) +func TestReadFrameHeader(t *testing.T) { + tests := []struct { + name string + input [5]byte + wantType derp.FrameType + wantLen uint32 + }{ + { + name: "SendPacket", + input: [5]byte{byte(derp.FrameSendPacket), 0x00, 0x00, 0x04, 0x00}, + wantType: derp.FrameSendPacket, + wantLen: 1024, + }, + { + name: "KeepAlive", + input: [5]byte{byte(derp.FrameKeepAlive), 0x00, 0x00, 0x00, 0x00}, + wantType: derp.FrameKeepAlive, + wantLen: 0, + }, + { + name: "MaxLen", + input: [5]byte{byte(derp.FrameRecvPacket), 0xff, 0xff, 0xff, 0xff}, + wantType: derp.FrameRecvPacket, + wantLen: 0xffffffff, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + br := bufio.NewReader(bytes.NewReader(tt.input[:])) + gotType, gotLen, err := derp.ReadFrameHeader(br) + if err != nil { + t.Fatalf("ReadFrameHeader: %v", err) + } + if gotType != tt.wantType { + t.Errorf("type = %v, want %v", gotType, tt.wantType) + } + if gotLen != tt.wantLen { + t.Errorf("len = %v, want %v", gotLen, tt.wantLen) + } + }) + } + + // Verify zero allocations. + buf := make([]byte, 4096) + rd := bytes.NewReader(buf) + br := bufio.NewReader(rd) + got := testing.AllocsPerRun(1000, func() { + rd.Reset(buf) + br.Reset(rd) + _, _, err := derp.ReadFrameHeader(br) + if err != nil { + t.Fatalf("ReadFrameHeader: %v", err) + } + }) + if got != 0 { + t.Fatalf("ReadFrameHeader allocs = %f, want 0", got) + } +} + +func TestWriteFrameHeader(t *testing.T) { + tests := []struct { + name string + typ derp.FrameType + frameLen uint32 + want [derp.FrameHeaderLen]byte + }{ + { + name: "SendPacket", + typ: derp.FrameSendPacket, + frameLen: 1024, + want: [derp.FrameHeaderLen]byte{byte(derp.FrameSendPacket), 0x00, 0x00, 0x04, 0x00}, + }, + { + name: "KeepAlive", + typ: derp.FrameKeepAlive, + frameLen: 0, + want: [derp.FrameHeaderLen]byte{byte(derp.FrameKeepAlive), 0x00, 0x00, 0x00, 0x00}, + }, + { + name: "MaxLen", + typ: derp.FrameRecvPacket, + frameLen: 0xffffffff, + want: [derp.FrameHeaderLen]byte{byte(derp.FrameRecvPacket), 0xff, 0xff, 0xff, 0xff}, + }, + } + for _, tt := range tests { + // Test fast path (empty buffer, plenty of space). + t.Run(tt.name+"/fast", func(t *testing.T) { + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) + if err := derp.WriteFrameHeader(bw, tt.typ, tt.frameLen); err != nil { + t.Fatalf("WriteFrameHeader: %v", err) + } + bw.Flush() + if got := buf.Bytes(); !bytes.Equal(got, tt.want[:]) { + t.Errorf("wrote % 02x, want % 02x", got, tt.want) + } + }) + + // Test slow path (buffer nearly full, less than FrameHeaderLen available). + t.Run(tt.name+"/slow", func(t *testing.T) { + var buf bytes.Buffer + const smallBuf = 8 // small enough to force slow path + bw := bufio.NewWriterSize(&buf, smallBuf) + // Fill buffer to leave less than FrameHeaderLen bytes available. + padding := make([]byte, smallBuf-derp.FrameHeaderLen+1) + if _, err := bw.Write(padding); err != nil { + t.Fatalf("Write padding: %v", err) + } + if err := derp.WriteFrameHeader(bw, tt.typ, tt.frameLen); err != nil { + t.Fatalf("WriteFrameHeader: %v", err) + } + bw.Flush() + got := buf.Bytes() + // The header is after the padding bytes. + got = got[len(padding):] + if !bytes.Equal(got, tt.want[:]) { + t.Errorf("wrote % 02x, want % 02x", got, tt.want) + } + }) + } + + // Verify zero allocations on fast path. + bw := bufio.NewWriter(io.Discard) + got := testing.AllocsPerRun(1000, func() { + if err := derp.WriteFrameHeader(bw, derp.FrameSendPacket, 1024); err != nil { + t.Fatalf("WriteFrameHeader: %v", err) + } + }) + if got != 0 { + t.Fatalf("WriteFrameHeader allocs = %f, want 0", got) + } +} + +type nopRead struct{} + +func (nopRead) Read(p []byte) (int, error) { return len(p), nil } + +func BenchmarkReadFrameHeader(b *testing.B) { + r := bufio.NewReader(nopRead{}) + b.ReportAllocs() + for b.Loop() { + _, _, err := derp.ReadFrameHeader(r) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkWriteFrameHeader(b *testing.B) { + bw := bufio.NewWriter(io.Discard) + b.ReportAllocs() + for b.Loop() { + if err := derp.WriteFrameHeader(bw, derp.FrameSendPacket, 1024); err != nil { + b.Fatal(err) + } + } +} + func TestClientInfoUnmarshal(t *testing.T) { for i, in := range map[string]struct { json string @@ -121,8 +280,7 @@ func TestSendRecv(t *testing.T) { } defer cin.Close() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() brwServer := bufio.NewReadWriter(bufio.NewReader(cin), bufio.NewWriter(cin)) go s.Accept(ctx, cin, brwServer, fmt.Sprintf("[abc::def]:%v", i)) @@ -195,7 +353,7 @@ func TestSendRecv(t *testing.T) { } } - serverMetrics := s.ExpVar().(*metrics.Set) + serverMetrics := s.ExpVar(false).(*metrics.Set) wantActive := func(total, home int64) { t.Helper() @@ -331,8 +489,7 @@ func TestSendFreeze(t *testing.T) { return c, c2 } - ctx, clientCtxCancel := context.WithCancel(context.Background()) - defer clientCtxCancel() + ctx := t.Context() aliceKey := key.NewNode() aliceClient, aliceConn := newClient(ctx, "alice", aliceKey) @@ -459,13 +616,13 @@ func TestSendFreeze(t *testing.T) { } } - t.Run("initial send", func(t *testing.T) { + t.Run("initial-send", func(t *testing.T) { drain(t, "bob") drain(t, "cathy") isEmpty(t, "alice") }) - t.Run("block cathy", func(t *testing.T) { + t.Run("block-cathy", func(t *testing.T) { // Block cathy. Now the cathyConn buffer will fill up quickly, // and the derp server will back up. cathyConn.SetReadBlock(true) @@ -716,8 +873,7 @@ func (c *testClient) close(t *testing.T) { // TestWatch tests the connection watcher mechanism used by regional // DERP nodes to mesh up with each other. func TestWatch(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() ts := newTestServer(t, ctx) defer ts.close(t) @@ -764,8 +920,7 @@ func waitConnect(t testing.TB, c *Client) { } func TestServerRepliesToPing(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() ts := newTestServer(t, ctx) defer ts.close(t) diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index ae530c93a31c0..304906b749257 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -299,9 +299,7 @@ func TestBreakWatcherConnRecv(t *testing.T) { errChan := make(chan error, 1) // Start the watcher thread (which connects to the watched server) - wg.Add(1) // To avoid using t.Logf after the test ends. See https://golang.org/issue/40343 - go func() { - defer wg.Done() + wg.Go(func() { var peers int add := func(m derp.PeerPresentMessage) { t.Logf("add: %v", m.Key.ShortString()) @@ -318,7 +316,7 @@ func TestBreakWatcherConnRecv(t *testing.T) { } watcher.RunWatchConnectionLoop(ctx, serverPrivateKey1.Public(), t.Logf, add, remove, notifyErr) - }() + }) synctest.Wait() @@ -381,9 +379,7 @@ func TestBreakWatcherConn(t *testing.T) { errorChan := make(chan error, 1) // Start the watcher thread (which connects to the watched server) - wg.Add(1) // To avoid using t.Logf after the test ends. See https://golang.org/issue/40343 - go func() { - defer wg.Done() + wg.Go(func() { var peers int add := func(m derp.PeerPresentMessage) { t.Logf("add: %v", m.Key.ShortString()) @@ -403,7 +399,7 @@ func TestBreakWatcherConn(t *testing.T) { } watcher1.RunWatchConnectionLoop(ctx, serverPrivateKey1.Public(), t.Logf, add, remove, notifyError) - }() + }) synctest.Wait() diff --git a/derp/derpserver/derpserver.go b/derp/derpserver/derpserver.go index f311eb25d9817..4e60fff675d07 100644 --- a/derp/derpserver/derpserver.go +++ b/derp/derpserver/derpserver.go @@ -30,6 +30,7 @@ import ( "os" "os/exec" "runtime" + "slices" "strconv" "strings" "sync" @@ -39,6 +40,7 @@ import ( "github.com/axiomhq/hyperloglog" "go4.org/mem" "golang.org/x/sync/errgroup" + xrate "golang.org/x/time/rate" "tailscale.com/client/local" "tailscale.com/derp" "tailscale.com/derp/derpconst" @@ -51,6 +53,7 @@ import ( "tailscale.com/tstime/rate" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/bufiox" "tailscale.com/util/ctxkey" "tailscale.com/util/mak" "tailscale.com/util/set" @@ -71,7 +74,7 @@ func init() { if keys == "" { return } - for _, keyStr := range strings.Split(keys, ",") { + for keyStr := range strings.SplitSeq(keys, ",") { k, err := key.ParseNodePublicUntyped(mem.S(keyStr)) if err != nil { log.Printf("ignoring invalid debug key %q: %v", keyStr, err) @@ -169,6 +172,8 @@ type Server struct { meshUpdateBatchSize *metrics.Histogram meshUpdateLoopCount *metrics.Histogram bufferedWriteFrames *metrics.Histogram // how many sendLoop frames (or groups of related frames) get written per flush + rateLimitPerClientWaited expvar.Int // number of times per-client rate limit caused a wait + // TODO(illotum): add metrics for rate limited wait time, consider total seconds vs a histogram. // verifyClientsLocalTailscaled only accepts client connections to the DERP // server if the clientKey is a known peer in the network, as specified by a @@ -178,7 +183,11 @@ type Server struct { verifyClientsURL string verifyClientsURLFailOpen bool - mu syncs.Mutex + perClientSendQueueDepth int // Sets the client send queue depth for the server. + tcpWriteTimeout time.Duration + clock tstime.Clock + + mu syncs.Mutex // guards the following fields closed bool netConns map[derp.Conn]chan struct{} // chan is closed when conn closes clients map[key.NodePublic]*clientSet @@ -194,16 +203,9 @@ type Server struct { // is gone from the region, we notify all of these watchers, // calling their funcs in a new goroutine. peerGoneWatchers map[key.NodePublic]set.HandleSet[func(key.NodePublic)] - // maps from netip.AddrPort to a client's public key - keyOfAddr map[netip.AddrPort]key.NodePublic - - // Sets the client send queue depth for the server. - perClientSendQueueDepth int - - tcpWriteTimeout time.Duration - - clock tstime.Clock + keyOfAddr map[netip.AddrPort]key.NodePublic + rateConfig RateConfig // per-client DERP frame rate limiting config } // clientSet represents 1 or more *sclients. @@ -227,10 +229,6 @@ type clientSet struct { // activeClient holds the currently active connection for the set. It's nil // if there are no connections or the connection is disabled. // - // A pointer to a clientSet can be held by peers for long periods of time - // without holding Server.mu to avoid mutex contention on Server.mu, only - // re-acquiring the mutex and checking the clients map if activeClient is - // nil. activeClient atomic.Pointer[sclient] // dup is non-nil if there are multiple connections for the @@ -506,6 +504,79 @@ func (s *Server) SetTCPWriteTimeout(d time.Duration) { s.tcpWriteTimeout = d } +// minRateLimitTokenBucketSize represents the minimum size of a token bucket +// applied for the purposes of rate limiting a DERP connection per received DERP +// frame. +// +// Note: The DERP protocol supports frames larger than this ([math.MaxUint32] length), +// but a [derp.FrameSendPacket] cannot exceed this value, which is what we optimize +// our token bucket calls for. +const minRateLimitTokenBucketSize = derp.MaxPacketSize + derp.KeyLen + +// RateConfig is a JSON-serializable configuration for rate limits. Values are +// in bytes. +type RateConfig struct { + // PerClientRateLimitBytesPerSec represents the per-client + // rate limit in bytes per second. A zero value disables all rate limiting. + PerClientRateLimitBytesPerSec uint64 `json:",omitzero"` + // PerClientRateBurstBytes represents the per-client token bucket depth, + // or burst, in bytes. Any value lower than [minRateLimitTokenBucketSize] + // will be increased to [minRateLimitTokenBucketSize] before application. Only + // relevant if PerClientRateLimitBytesPerSec is nonzero. + PerClientRateBurstBytes uint64 `json:",omitzero"` +} + +// LoadRateConfig reads and JSON-unmarshals a [RateConfig] from the file at path. +func LoadRateConfig(path string) (RateConfig, error) { + if path == "" { + return RateConfig{}, errors.New("rate config path is empty") + } + b, err := os.ReadFile(path) + if err != nil { + return RateConfig{}, fmt.Errorf("error reading rate config: %w", err) + } + var rc RateConfig + if err := json.Unmarshal(b, &rc); err != nil { + return RateConfig{}, fmt.Errorf("error parsing rate config: %w", err) + } + return rc, nil +} + +// LoadAndApplyRateConfig reads a [RateConfig] from the file at path and +// applies it to the server via [Server.UpdateRateLimits]. +func (s *Server) LoadAndApplyRateConfig(path string) error { + rc, err := LoadRateConfig(path) + if err != nil { + return err + } + applied := s.UpdateRateLimits(rc) + s.logf("rate config applied: client-rate=%d bytes/sec, client-burst=%d bytes", + applied.PerClientRateLimitBytesPerSec, applied.PerClientRateBurstBytes) + return nil +} + +// UpdateRateLimits sets the receive rate limits, updating all existing client +// connections. It returns the applied config, which may differ from rc. If the +// per-client rate limits is 0, rate limiting is disabled. Mesh peers are always +// exempt from rate limiting. +func (s *Server) UpdateRateLimits(rc RateConfig) (applied RateConfig) { + s.mu.Lock() + defer s.mu.Unlock() + if rc.PerClientRateLimitBytesPerSec == 0 { + // all rate limiting is disabled + rc = RateConfig{} + } else { + rc.PerClientRateBurstBytes = max(rc.PerClientRateBurstBytes, minRateLimitTokenBucketSize) + } + s.rateConfig = rc + for _, cs := range s.clients { + cs.ForeachClient(func(c *sclient) { + c.setRateLimit(rc.PerClientRateLimitBytesPerSec, rc.PerClientRateBurstBytes) + }) + } + return rc +} + // HasMeshKey reports whether the server is configured with a mesh key. func (s *Server) HasMeshKey() bool { return !s.meshKey.IsZero() } @@ -668,6 +739,8 @@ func (s *Server) registerClient(c *sclient) { s.mu.Lock() defer s.mu.Unlock() + c.setRateLimit(s.rateConfig.PerClientRateLimitBytesPerSec, s.rateConfig.PerClientRateBurstBytes) + cs, ok := s.clients[c.key] if !ok { c.debugLogf("register single client") @@ -941,7 +1014,7 @@ func (s *Server) accept(ctx context.Context, nc derp.Conn, brw *bufio.ReadWriter br: br, bw: bw, logf: logger.WithPrefix(s.logf, fmt.Sprintf("derp client %v%s: ", remoteAddr, clientKey.ShortString())), - done: ctx.Done(), + ctx: ctx, remoteIPPort: remoteIPPort, connectedAt: s.clock.Now(), sendQueue: make(chan pkt, s.perClientSendQueueDepth), @@ -1004,7 +1077,12 @@ func (c *sclient) run(ctx context.Context) error { } }() - c.startStatsLoop(sendCtx) + // Allow disabling RTT stats collection to reduce + // CPU and syscalls on servers with high connection + // counts + if !envknob.Bool("TS_DERP_DISABLE_RTT_STATS") { + c.startStatsLoop(sendCtx) + } for { ft, fl, err := derp.ReadFrameHeader(c.br) @@ -1020,6 +1098,13 @@ func (c *sclient) run(ctx context.Context) error { } return fmt.Errorf("client %s: readFrameHeader: %w", c.key.ShortString(), err) } + // Rate-limit by DERP frame length (fl), which excludes TLS protocol and + // DERP frame length field overheads. + // Note: meshed clients are exempt from rate limits. + if err := c.rateLimit(int(fl)); err != nil { + return err // context canceled, connection closing + } + c.s.noteClientActivity(c) switch ft { case derp.FrameNotePreferred: @@ -1082,13 +1167,14 @@ func (c *sclient) handleFramePing(ft derp.FrameType, fl uint32) error { // space for future extensibility, but not too much. return fmt.Errorf("ping body too large: %v", fl) } - _, err := io.ReadFull(c.br, m[:]) - if err != nil { + if _, err := bufiox.ReadFull(c.br, m[:]); err != nil { return err } + var err error if extra := int64(fl) - int64(len(m)); extra > 0 { _, err = io.CopyN(io.Discard, c.br, extra) } + select { case c.sendPongCh <- [8]byte(m): default: @@ -1132,7 +1218,7 @@ func (c *sclient) handleFrameClosePeer(ft derp.FrameType, fl uint32) error { // handleFrameForwardPacket reads a "forward packet" frame from the client // (which must be a trusted client, a peer in our mesh). -func (c *sclient) handleFrameForwardPacket(ft derp.FrameType, fl uint32) error { +func (c *sclient) handleFrameForwardPacket(_ derp.FrameType, fl uint32) error { if !c.canMesh { return fmt.Errorf("insufficient permissions") } @@ -1175,7 +1261,7 @@ func (c *sclient) handleFrameForwardPacket(ft derp.FrameType, fl uint32) error { } // handleFrameSendPacket reads a "send packet" frame from the client. -func (c *sclient) handleFrameSendPacket(ft derp.FrameType, fl uint32) error { +func (c *sclient) handleFrameSendPacket(_ derp.FrameType, fl uint32) error { s := c.s dstKey, contents, err := s.recvPacket(c.br, fl) @@ -1228,6 +1314,80 @@ func (c *sclient) handleFrameSendPacket(ft derp.FrameType, fl uint32) error { return c.sendPkt(dst, p) } +// setRateLimit updates the receive rate limiter. When bytesPerSec is 0, or the +// client is a mesh peer, the limiter is set to nil so that [sclient.rateLimit] is a no-op. +func (c *sclient) setRateLimit(bytesPerSec, burst uint64) { + if c.canMesh || bytesPerSec == 0 { + c.recvLim.Store(nil) + return + } + limiter := xrate.NewLimiter(xrate.Limit(bytesPerSec), int(burst)) + c.recvLim.Store(limiter) +} + +// rateLimitWait is a reimplementation of [xrate.Limiter.WaitN] via [xrate.Limiter.ReserveN]. +// It returns the duration waited for tokens to become available. +func rateLimitWait(ctx context.Context, lim *xrate.Limiter, n int, now time.Time, newTimer func(time.Duration) (<-chan time.Time, func() bool)) (time.Duration, error) { + r := lim.ReserveN(now, n) + if !r.OK() { + return 0, fmt.Errorf("rate: Wait(n=%d) exceeds limiter's burst %d", n, lim.Burst()) + } + delay := r.DelayFrom(now) + if delay == 0 { + return 0, nil + } + ch, stop := newTimer(delay) + defer stop() + select { + case <-ch: + // Note: We return the predicted delay as wall-clock duration. May be not the same. + return delay, nil + case <-ctx.Done(): + r.Cancel() + return 0, ctx.Err() + } +} + +// rateLimit applies the receive rate limit. +// By limiting here we prevent reading from the buffered reader +// [sclient.br] if the limit has been exceeded. Any reads done here provide space +// within the buffered reader to fill back in with data from +// the TCP socket. Pacing reads acts as a form of natural +// backpressure via TCP flow control. +// When rate limiting is disabled or the client is a mesh peer, recvLim is nil +// and this is a no-op. +func (c *sclient) rateLimit(n int) error { + if lim := c.recvLim.Load(); lim != nil { + newTimer := func(d time.Duration) (<-chan time.Time, func() bool) { + tc, ch := c.s.clock.NewTimer(d) + return ch, tc.Stop + } + // If n exceeds the capacity of the bucket, then WaitN will return + // an error and consume zero tokens. To prevent this, clamp n to + // [minRateLimitTokenBucketSize]. + // + // While we could call WaitN multiple times and/or more precisely for + // lim.Burst(), it's better to return early as a larger DERP frame: + // 1. is unexpected + // 2. is only partially read off the socket (bufio) + // 3. would cause the connection to close shortly after rate limiting, anyway. + clampedN := min(n, minRateLimitTokenBucketSize) + now := c.s.clock.Now() + var ( + durationWaited time.Duration + err error + ) + durationWaited, err = rateLimitWait(c.ctx, lim, clampedN, now, newTimer) + if err != nil { + return err + } + if durationWaited > 0 { + c.s.rateLimitPerClientWaited.Add(1) + } + } + return nil +} + func (c *sclient) debugLogf(format string, v ...any) { if c.debug { c.logf(format, v...) @@ -1287,9 +1447,9 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error { if disco.LooksLikeDiscoWrapper(p.bs) { sendQueue = dst.discoSendQueue } - for attempt := 0; attempt < 3; attempt++ { + for attempt := range 3 { select { - case <-dst.done: + case <-dst.ctx.Done(): s.recordDrop(p.bs, c.key, dstKey, dropReasonGoneDisconnected) dst.debugLogf("sendPkt attempt %d dropped, dst gone", attempt) return nil @@ -1334,7 +1494,7 @@ func (c *sclient) requestPeerGoneWrite(peer key.NodePublic, reason derp.PeerGone peer: peer, reason: reason, }: - case <-c.done: + case <-c.ctx.Done(): } } @@ -1484,16 +1644,13 @@ func (s *Server) noteClientActivity(c *sclient) { // If we saw this connection send previously, then consider // the group fighting and disable them all. if s.dupPolicy == disableFighters { - for _, prior := range dup.sendHistory { - if prior == c { - cs.ForeachClient(func(c *sclient) { - c.isDisabled.Store(true) - if cs.activeClient.Load() == c { - cs.activeClient.Store(nil) - } - }) - break - } + if slices.Contains(dup.sendHistory, c) { + cs.ForeachClient(func(c *sclient) { + c.isDisabled.Store(true) + if cs.activeClient.Load() == c { + cs.activeClient.Store(nil) + } + }) } } @@ -1519,7 +1676,7 @@ func (s *Server) sendServerInfo(bw *lazyBufioWriter, clientKey key.NodePublic) e return bw.Flush() } -// recvClientKey reads the frameClientInfo frame from the client (its +// recvClientKey reads the FrameClientInfo frame from the client (its // proof of identity) upon its initial connection. It should be // considered especially untrusted at this point. func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.NodePublic, info *derp.ClientInfo, err error) { @@ -1622,7 +1779,7 @@ type sclient struct { key key.NodePublic info derp.ClientInfo logf logger.Logf - done <-chan struct{} // closed when connection closes + ctx context.Context // closed when connection closes remoteIPPort netip.AddrPort // zero if remoteAddr is not ip:port. sendQueue chan pkt // packets queued to this client; never closed discoSendQueue chan pkt // important packets queued to this client; never closed @@ -1662,6 +1819,16 @@ type sclient struct { // client that it's trying to establish a direct connection // through us with a peer we have no record of. peerGoneLim *rate.Limiter + + // recvLim is the receive rate limiter. When rate limiting is enabled for a + // non-mesh client, it points to a [xrate.Limiter]. When rate limiting + // is disabled or the client is a mesh peer, it is nil and [sclient.rateLimit] + // is a no-op. Updated atomically by [sclient.setRateLimit] so that + // [sclient.rateLimit] can load it without holding [Server.mu]. + // + // TODO: consider porting the required APIs from [xrate.Limiter] to [rate.Limiter], + // which is already optimized to use [mono.Time]. + recvLim atomic.Pointer[xrate.Limiter] } func (c *sclient) presentFlags() derp.PeerPresentFlags { @@ -2202,7 +2369,7 @@ func (s *Server) expVarFunc(f func() any) expvar.Func { } // ExpVar returns an expvar variable suitable for registering with expvar.Publish. -func (s *Server) ExpVar() expvar.Var { +func (s *Server) ExpVar(rateLimitEnabled bool) expvar.Var { m := new(metrics.Set) m.Set("gauge_memstats_sys0", expvar.Func(func() any { return int64(s.memSys0) })) m.Set("gauge_watchers", s.expVarFunc(func() any { return len(s.watchers) })) @@ -2210,9 +2377,9 @@ func (s *Server) ExpVar() expvar.Var { m.Set("gauge_current_connections", &s.curClients) m.Set("gauge_current_home_connections", &s.curHomeClients) m.Set("gauge_current_notideal_connections", &s.curClientsNotIdeal) - m.Set("gauge_clients_total", expvar.Func(func() any { return len(s.clientsMesh) })) - m.Set("gauge_clients_local", expvar.Func(func() any { return len(s.clients) })) - m.Set("gauge_clients_remote", expvar.Func(func() any { return len(s.clientsMesh) - len(s.clients) })) + m.Set("gauge_clients_total", s.expVarFunc(func() any { return len(s.clientsMesh) })) + m.Set("gauge_clients_local", s.expVarFunc(func() any { return len(s.clients) })) + m.Set("gauge_clients_remote", s.expVarFunc(func() any { return len(s.clientsMesh) - len(s.clients) })) m.Set("gauge_current_dup_client_keys", &s.dupClientKeys) m.Set("gauge_current_dup_client_conns", &s.dupClientConns) m.Set("counter_total_dup_client_conns", &s.dupClientConnTotal) @@ -2245,6 +2412,18 @@ func (s *Server) ExpVar() expvar.Var { var expvarVersion expvar.String expvarVersion.Set(version.Long()) m.Set("version", &expvarVersion) + if rateLimitEnabled { + // Rate limiting is currently experimental, its APIs are unstable, and it must + // be opted-in via --rate-config. Therefore, we only publish related metrics + // on demand, to avoid polluting uninterested metrics consumers. + m.Set("rate_limit_per_client_bytes_per_second", s.expVarFunc(func() any { + return s.rateConfig.PerClientRateLimitBytesPerSec + })) + m.Set("rate_limit_per_client_burst_bytes", s.expVarFunc(func() any { + return s.rateConfig.PerClientRateBurstBytes + })) + m.Set("rate_limit_per_client_waited", &s.rateLimitPerClientWaited) + } return m } diff --git a/derp/derpserver/derpserver_test.go b/derp/derpserver/derpserver_test.go index 3a778d59fb009..7143a9b3d62e4 100644 --- a/derp/derpserver/derpserver_test.go +++ b/derp/derpserver/derpserver_test.go @@ -15,10 +15,12 @@ import ( "log" "net" "os" + "path/filepath" "reflect" "strconv" "sync" "testing" + "testing/synctest" "time" "github.com/axiomhq/hyperloglog" @@ -29,6 +31,7 @@ import ( "tailscale.com/derp/derpconst" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/set" ) const testMeshKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" @@ -627,22 +630,17 @@ func BenchmarkConcurrentStreams(b *testing.B) { if err != nil { b.Fatal(err) } - defer ln.Close() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := b.Context() + acceptDone := make(chan struct{}) go func() { - for ctx.Err() == nil { + defer close(acceptDone) + for { connIn, err := ln.Accept() if err != nil { - if ctx.Err() != nil { - return - } - b.Error(err) return } - brwServer := bufio.NewReadWriter(bufio.NewReader(connIn), bufio.NewWriter(connIn)) go s.Accept(ctx, connIn, brwServer, "test-client") } @@ -680,6 +678,9 @@ func BenchmarkConcurrentStreams(b *testing.B) { } } }) + + ln.Close() + <-acceptDone } func BenchmarkSendRecv(b *testing.B) { @@ -769,7 +770,7 @@ func TestServeDebugTrafficUniqueSenders(t *testing.T) { senderCardinality: hyperloglog.New(), } - for i := 0; i < 5; i++ { + for range 5 { c.senderCardinality.Insert(key.NewNode().Public().AppendTo(nil)) } @@ -845,7 +846,7 @@ func TestSenderCardinality(t *testing.T) { t.Errorf("EstimatedUniqueSenders() = %d, want ~10 (8-12 range)", estimate) } - for i := 0; i < 5; i++ { + for i := range 5 { c.senderCardinality.Insert(senders[i].AppendTo(nil)) } @@ -869,7 +870,7 @@ func TestSenderCardinality100(t *testing.T) { } numSenders := 100 - for i := 0; i < numSenders; i++ { + for range numSenders { c.senderCardinality.Insert(key.NewNode().Public().AppendTo(nil)) } @@ -945,7 +946,7 @@ func BenchmarkHyperLogLogInsertUnique(b *testing.B) { func BenchmarkHyperLogLogEstimate(b *testing.B) { hll := hyperloglog.New() - for i := 0; i < 100; i++ { + for range 100 { hll.Insert(key.NewNode().Public().AppendTo(nil)) } @@ -955,6 +956,495 @@ func BenchmarkHyperLogLogEstimate(b *testing.B) { } } +func TestPerClientRateLimit(t *testing.T) { + t.Run("throttled", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + s := New(key.NewNode(), logger.Discard) + defer s.Close() + + c := &sclient{ + ctx: ctx, + s: s, + } + lim := rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize) + c.recvLim.Store(lim) + wantTokens := func(t *testing.T, wantTokens float64) { + t.Helper() + if lim.Tokens() != wantTokens { + t.Fatalf("want tokens: %v got: %v", wantTokens, lim.Tokens()) + } + } + + // First call within burst should not block. + c.rateLimit(minRateLimitTokenBucketSize) + + wantTokens(t, 0) + + // Next call exceeds burst, should block until tokens replenish. + done := make(chan error, 1) + go func() { + done <- c.rateLimit(minRateLimitTokenBucketSize) + }() + + // After settling, the goroutine should be blocked (no result yet). + synctest.Wait() + select { + case err := <-done: + t.Fatalf("rateLimit should have blocked, but returned: %v", err) + default: + } + + // Advance time by 1 second, the goroutine should be unblocked + time.Sleep(1 * time.Second) + synctest.Wait() + + select { + case err := <-done: + if err != nil { + t.Fatalf("rateLimit after time advance: %v", err) + } + default: + t.Fatal("rateLimit should have unblocked after 1s") + } + + wantTokens(t, 0) + + // The second rateLimit call had to wait + if got := s.rateLimitPerClientWaited.Value(); got != 1 { + t.Fatalf("rateLimitPerClientWaited = %d, want 1", got) + } + }) + }) + + t.Run("context_canceled", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + s := New(key.NewNode(), logger.Discard) + defer s.Close() + + c := &sclient{ + ctx: ctx, + s: s, + } + lim := rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize) + c.recvLim.Store(lim) + + // Exhaust burst. + if err := c.rateLimit(minRateLimitTokenBucketSize); err != nil { + t.Fatalf("rateLimit: %v", err) + } + + done := make(chan error, 1) + go func() { + done <- c.rateLimit(minRateLimitTokenBucketSize) + }() + synctest.Wait() + + // Cancel the context; the blocked rateLimit should return an error. + cancel() + synctest.Wait() + + select { + case err := <-done: + if err == nil { + t.Fatal("expected error from canceled context") + } + default: + t.Fatal("rateLimit should have returned after context cancelation") + } + }) + }) + + t.Run("mesh_peer_exempt", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Mesh peers have nil recvLim, so rate limiting is a no-op. + c := &sclient{ + ctx: ctx, + canMesh: true, + } + + if err := c.rateLimit(1000); err != nil { + t.Fatalf("mesh peer rateLimit should be no-op: %v", err) + } + }) + + t.Run("zero_config_no_limiter", func(t *testing.T) { + s := New(key.NewNode(), logger.Discard) + defer s.Close() + if !reflect.DeepEqual(s.rateConfig, RateConfig{}) { + t.Errorf("expected zero rate limit, got %+v", s.rateConfig) + } + }) +} + +// zeroTimer returns a timer that fires immediately. +func zeroTimer(_ time.Duration) (<-chan time.Time, func() bool) { + t := time.NewTimer(0) + return t.C, t.Stop +} + +// neverTimer returns a timer that never fires. +func neverTimer(_ time.Duration) (<-chan time.Time, func() bool) { + return make(chan time.Time), func() bool { return false } +} + +func TestRateLimitWait(t *testing.T) { + ctx := context.Background() + + t.Run("no_wait", func(t *testing.T) { + lim := rate.NewLimiter(10, 10) + waited, err := rateLimitWait(ctx, lim, 5, time.Now(), zeroTimer) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if waited != 0 { + t.Fatalf("waited = %v, want 0", waited) + } + }) + + t.Run("wait_for_tokens", func(t *testing.T) { + lim := rate.NewLimiter(10, 10) + now := time.Now() + waited, err := rateLimitWait(ctx, lim, 10, now, zeroTimer) // exhaust all tokens + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if waited != 0 { + t.Fatalf("waited = %v, want 0", waited) + } + waited, err = rateLimitWait(ctx, lim, 10, now, zeroTimer) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if waited == 0 { + t.Fatal("waited = 0, want > 0") + } + }) + + t.Run("context_canceled", func(t *testing.T) { + lim := rate.NewLimiter(10, 10) + now := time.Now() + _, err := rateLimitWait(ctx, lim, 10, now, zeroTimer) // exhaust all tokens + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + canceled, cancel := context.WithCancel(ctx) // cancel context so the select picks ctx.Done() + cancel() + waited, err := rateLimitWait(canceled, lim, 10, now, neverTimer) // neverTimer to only unblock via context + if err == nil { + t.Fatal("expected error from canceled context") + } + if waited != 0 { + t.Fatalf("waited = %v, want 0", waited) + } + }) + + t.Run("n_exceeds_burst", func(t *testing.T) { + lim := rate.NewLimiter(10, 5) + waited, err := rateLimitWait(ctx, lim, 10, time.Now(), zeroTimer) + if err == nil { + t.Fatal("expected error when n > burst") + } + if waited != 0 { + t.Fatalf("waited = %v, want 0", waited) + } + }) +} + +func verifyLimiter(t *testing.T, lim *rate.Limiter, wantRateConfig RateConfig) { + t.Helper() + if got := lim.Limit(); got != rate.Limit(wantRateConfig.PerClientRateLimitBytesPerSec) { + t.Errorf("client rate limit = %v; want %d", got, wantRateConfig.PerClientRateLimitBytesPerSec) + } + if got := lim.Burst(); got != int(wantRateConfig.PerClientRateBurstBytes) { + t.Errorf("client burst = %v; want %d", got, wantRateConfig.PerClientRateBurstBytes) + } +} + +func TestUpdateRateLimits(t *testing.T) { + const ( + testClientBurst1 = minRateLimitTokenBucketSize + 1 + testClientRate1 = minRateLimitTokenBucketSize + 2 + testClientBurst2 = minRateLimitTokenBucketSize + 3 + testClientRate2 = minRateLimitTokenBucketSize + 4 + ) + + s := New(key.NewNode(), t.Logf) + defer s.Close() + + // Create a non-mesh client with no initial limiter. + clientKey := key.NewNode().Public() + c := &sclient{ + key: clientKey, + s: s, + logf: logger.Discard, + canMesh: false, + } + cs := &clientSet{} + cs.activeClient.Store(c) + + s.mu.Lock() + s.clients[clientKey] = cs + s.mu.Unlock() + + rc := RateConfig{ + PerClientRateLimitBytesPerSec: testClientRate1, + PerClientRateBurstBytes: testClientBurst1, + } + s.UpdateRateLimits(rc) + + lim := c.recvLim.Load() + if lim == nil { + t.Fatal("expected non-nil limiter after update") + } + verifyLimiter(t, lim, rc) + + // Verify server fields updated. + s.mu.Lock() + if !reflect.DeepEqual(s.rateConfig, rc) { + t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, rc) + } + s.mu.Unlock() + + // Update again with different nonzero values. + rc = RateConfig{ + PerClientRateLimitBytesPerSec: testClientRate2, + PerClientRateBurstBytes: testClientBurst2, + } + s.UpdateRateLimits(rc) + lim = c.recvLim.Load() + if lim == nil { + t.Fatal("expected non-nil limiter") + } + verifyLimiter(t, lim, rc) + + // Disable rate limiting (set to 0). + s.UpdateRateLimits(RateConfig{}) + + if got := c.recvLim.Load(); got != nil { + t.Errorf("expected nil limiter after disable, got limit=%v", got.Limit()) + } + + // Mesh peer should always have nil limiter regardless of update. + meshKey := key.NewNode().Public() + meshClient := &sclient{ + key: meshKey, + s: s, + logf: logger.Discard, + canMesh: true, + } + meshCS := &clientSet{} + meshCS.activeClient.Store(meshClient) + + s.mu.Lock() + s.clients[meshKey] = meshCS + s.mu.Unlock() + + rc = RateConfig{ + PerClientRateLimitBytesPerSec: testClientRate2, + PerClientRateBurstBytes: testClientBurst2, + } + s.UpdateRateLimits(rc) + + if got := meshClient.recvLim.Load(); got != nil { + t.Errorf("mesh peer should have nil limiter, got limit=%v", got.Limit()) + } + // Non-mesh client should be updated. + lim = c.recvLim.Load() + if lim == nil { + t.Fatal("expected non-nil limiter for non-mesh client") + } + verifyLimiter(t, lim, rc) + + // Verify dup clients are also updated. + dupKey := key.NewNode().Public() + d1 := &sclient{key: dupKey, s: s, logf: logger.Discard} + d2 := &sclient{key: dupKey, s: s, logf: logger.Discard} + dupCS := &clientSet{} + dupCS.activeClient.Store(d1) + dupCS.dup = &dupClientSet{set: set.Of(d1, d2)} + s.mu.Lock() + s.clients[dupKey] = dupCS + s.mu.Unlock() + + rc = RateConfig{ + PerClientRateLimitBytesPerSec: testClientRate1, + PerClientRateBurstBytes: testClientBurst1, + } + s.UpdateRateLimits(rc) + for i, d := range []*sclient{d1, d2} { + dl := d.recvLim.Load() + if dl == nil { + t.Fatalf("dup client %d: expected non-nil limiter", i) + } + verifyLimiter(t, dl, rc) + } +} + +func TestLoadRateConfig(t *testing.T) { + for _, tt := range []struct { + name string + json string + wantRateConfig RateConfig + }{ + {"all_set", `{"PerClientRateLimitBytesPerSec": 1, "PerClientRateBurstBytes": 2}`, RateConfig{ + PerClientRateLimitBytesPerSec: 1, + PerClientRateBurstBytes: 2, + }}, + {"rate_only", `{"PerClientRateLimitBytesPerSec": 1}`, RateConfig{ + PerClientRateLimitBytesPerSec: 1, + }}, + {"zeros", `{"PerClientRateLimitBytesPerSec": 0, "PerClientRateBurstBytes": 0}`, RateConfig{}}, + {"empty_json", `{}`, RateConfig{}}, + } { + t.Run(tt.name, func(t *testing.T) { + f := filepath.Join(t.TempDir(), "rate.json") + if err := os.WriteFile(f, []byte(tt.json), 0644); err != nil { + t.Fatal(err) + } + rc, err := LoadRateConfig(f) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(rc, tt.wantRateConfig) { + t.Errorf("rate config = %v want %v", rc, tt.wantRateConfig) + } + }) + } + + for _, tt := range []struct { + name string + path string + content string // written to loaded path if non-empty; path used as-is if empty + }{ + {"empty_path", "", ""}, + {"missing_file", filepath.Join(t.TempDir(), "nonexistent.json"), ""}, + {"invalid_json", "", "not json"}, + } { + t.Run(tt.name, func(t *testing.T) { + path := tt.path + if tt.content != "" { + path = filepath.Join(t.TempDir(), "rate.json") + if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil { + t.Fatal(err) + } + } + _, err := LoadRateConfig(path) + if err == nil { + t.Fatal("expected error") + } + }) + } +} + +func TestLoadAndApplyRateConfig(t *testing.T) { + writeConfig := func(t *testing.T, json string) string { + t.Helper() + f := filepath.Join(t.TempDir(), "rate.json") + if err := os.WriteFile(f, []byte(json), 0644); err != nil { + t.Fatal(err) + } + return f + } + + t.Run("applies_and_updates_clients", func(t *testing.T) { + s := New(key.NewNode(), t.Logf) + defer s.Close() + + clientKey := key.NewNode().Public() + c := &sclient{key: clientKey, s: s, logf: logger.Discard} + cs := &clientSet{} + cs.activeClient.Store(c) + s.mu.Lock() + s.clients[clientKey] = cs + s.mu.Unlock() + + f := writeConfig(t, fmt.Sprintf(`{"PerClientRateLimitBytesPerSec": %d, "PerClientRateBurstBytes": %d}`, + minRateLimitTokenBucketSize, minRateLimitTokenBucketSize+1)) + if err := s.LoadAndApplyRateConfig(f); err != nil { + t.Fatalf("LoadAndApplyRateConfig: %v", err) + } + + // Verify server fields. + wantRateConfig := RateConfig{ + PerClientRateLimitBytesPerSec: minRateLimitTokenBucketSize, + PerClientRateBurstBytes: minRateLimitTokenBucketSize + 1, + } + s.mu.Lock() + if !reflect.DeepEqual(s.rateConfig, wantRateConfig) { + t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, wantRateConfig) + } + s.mu.Unlock() + + // Verify client limiter. + lim := c.recvLim.Load() + if lim == nil { + t.Fatal("expected non-nil limiter") + } + verifyLimiter(t, lim, wantRateConfig) + }) + + t.Run("burst_is_at_least_minRateLimitTokenBucketSize", func(t *testing.T) { + s := New(key.NewNode(), t.Logf) + defer s.Close() + + f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 10}`) + if err := s.LoadAndApplyRateConfig(f); err != nil { + t.Fatalf("LoadAndApplyRateConfig: %v", err) + } + + s.mu.Lock() + gotClientBurst := s.rateConfig.PerClientRateBurstBytes + s.mu.Unlock() + if gotClientBurst != minRateLimitTokenBucketSize { + t.Errorf("client burst = %d; want %d", gotClientBurst, minRateLimitTokenBucketSize) + } + }) + + t.Run("reload_disables_limiting", func(t *testing.T) { + s := New(key.NewNode(), t.Logf) + defer s.Close() + + f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000}`) + if err := s.LoadAndApplyRateConfig(f); err != nil { + t.Fatal(err) + } + s.mu.Lock() + if reflect.DeepEqual(s.rateConfig, RateConfig{}) { + t.Error("s.rateConfig is zero val; want nonzero rates") + } + s.mu.Unlock() + + if err := os.WriteFile(f, []byte(`{}`), 0644); err != nil { + t.Fatal(err) + } + if err := s.LoadAndApplyRateConfig(f); err != nil { + t.Fatal(err) + } + + s.mu.Lock() + if !reflect.DeepEqual(s.rateConfig, RateConfig{}) { + t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, RateConfig{}) + } + s.mu.Unlock() + }) + + t.Run("propagates_errors", func(t *testing.T) { + s := New(key.NewNode(), t.Logf) + defer s.Close() + + if err := s.LoadAndApplyRateConfig(filepath.Join(t.TempDir(), "nonexistent.json")); err == nil { + t.Fatal("expected error") + } + }) +} + func BenchmarkSenderCardinalityOverhead(b *testing.B) { hll := hyperloglog.New() sender := key.NewNode().Public() diff --git a/derp/xdp/xdp_linux.go b/derp/xdp/xdp_linux.go index 5d22716be4f16..7ab23bd2e9eed 100644 --- a/derp/xdp/xdp_linux.go +++ b/derp/xdp/xdp_linux.go @@ -62,8 +62,7 @@ func NewSTUNServer(config *STUNServerConfig, opts ...STUNServerOption) (*STUNSer objs := new(bpfObjects) err = loadBpfObjects(objs, nil) if err != nil { - var ve *ebpf.VerifierError - if config.FullVerifierErr && errors.As(err, &ve) { + if ve, ok := errors.AsType[*ebpf.VerifierError](err); config.FullVerifierErr && ok { err = fmt.Errorf("verifier error: %+v", ve) } return nil, fmt.Errorf("error loading XDP program: %w", err) diff --git a/derp/xdp/xdp_linux_test.go b/derp/xdp/xdp_linux_test.go index 5c75a69ff3fbb..d8de2bf62d24d 100644 --- a/derp/xdp/xdp_linux_test.go +++ b/derp/xdp/xdp_linux_test.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checksum" "gvisor.dev/gvisor/pkg/tcpip/header" "tailscale.com/net/stun" + "tailscale.com/tstest" ) type xdpAction uint32 @@ -271,6 +272,7 @@ func getIPv6STUNBindingResp() []byte { } func TestXDP(t *testing.T) { + tstest.RequireRoot(t) ipv4STUNBindingReqTX := getIPv4STUNBindingReq(nil) ipv6STUNBindingReqTX := getIPv6STUNBindingReq(nil) @@ -447,7 +449,7 @@ func TestXDP(t *testing.T) { wantMetrics map[bpfCountersKey]uint64 }{ { - name: "ipv4 STUN Binding Request Drop STUN", + name: "ipv4-STUN-Binding-Request-Drop-STUN", dropSTUN: true, packetIn: ipv4STUNBindingReqTX, wantCode: xdpActionDrop, @@ -466,7 +468,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv6 STUN Binding Request Drop STUN", + name: "ipv6-STUN-Binding-Request-Drop-STUN", dropSTUN: true, packetIn: ipv6STUNBindingReqTX, wantCode: xdpActionDrop, @@ -485,7 +487,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request TX", + name: "ipv4-STUN-Binding-Request-TX", packetIn: ipv4STUNBindingReqTX, wantCode: xdpActionTX, wantPacketOut: getIPv4STUNBindingResp(), @@ -503,7 +505,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv6 STUN Binding Request TX", + name: "ipv6-STUN-Binding-Request-TX", packetIn: ipv6STUNBindingReqTX, wantCode: xdpActionTX, wantPacketOut: getIPv6STUNBindingResp(), @@ -521,7 +523,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request invalid ip csum PASS", + name: "ipv4-STUN-Binding-Request-invalid-ip-csum-PASS", packetIn: ipv4STUNBindingReqIPCsumPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqIPCsumPass, @@ -539,7 +541,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request ihl PASS", + name: "ipv4-STUN-Binding-Request-ihl-PASS", packetIn: ipv4STUNBindingReqIHLPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqIHLPass, @@ -557,7 +559,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request ip version PASS", + name: "ipv4-STUN-Binding-Request-ip-version-PASS", packetIn: ipv4STUNBindingReqIPVerPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqIPVerPass, @@ -575,7 +577,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request ip proto PASS", + name: "ipv4-STUN-Binding-Request-ip-proto-PASS", packetIn: ipv4STUNBindingReqIPProtoPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqIPProtoPass, @@ -593,7 +595,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request frag offset PASS", + name: "ipv4-STUN-Binding-Request-frag-offset-PASS", packetIn: ipv4STUNBindingReqFragOffsetPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqFragOffsetPass, @@ -611,7 +613,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request flags mf PASS", + name: "ipv4-STUN-Binding-Request-flags-mf-PASS", packetIn: ipv4STUNBindingReqFlagsMFPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqFlagsMFPass, @@ -629,7 +631,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request tot len PASS", + name: "ipv4-STUN-Binding-Request-tot-len-PASS", packetIn: ipv4STUNBindingReqTotLenPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqTotLenPass, @@ -647,7 +649,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv6 STUN Binding Request ip version PASS", + name: "ipv6-STUN-Binding-Request-ip-version-PASS", packetIn: ipv6STUNBindingReqIPVerPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqIPVerPass, @@ -665,7 +667,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv6 STUN Binding Request next hdr PASS", + name: "ipv6-STUN-Binding-Request-next-hdr-PASS", packetIn: ipv6STUNBindingReqNextHdrPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqNextHdrPass, @@ -683,7 +685,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv6 STUN Binding Request payload len PASS", + name: "ipv6-STUN-Binding-Request-payload-len-PASS", packetIn: ipv6STUNBindingReqPayloadLenPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqPayloadLenPass, @@ -701,7 +703,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request UDP csum PASS", + name: "ipv4-STUN-Binding-Request-UDP-csum-PASS", packetIn: ipv4STUNBindingReqUDPCsumPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqUDPCsumPass, @@ -719,7 +721,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv6 STUN Binding Request UDP csum PASS", + name: "ipv6-STUN-Binding-Request-UDP-csum-PASS", packetIn: ipv6STUNBindingReqUDPCsumPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqUDPCsumPass, @@ -737,7 +739,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request STUN type PASS", + name: "ipv4-STUN-Binding-Request-STUN-type-PASS", packetIn: ipv4STUNBindingReqSTUNTypePass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqSTUNTypePass, @@ -755,7 +757,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv6 STUN Binding Request STUN type PASS", + name: "ipv6-STUN-Binding-Request-STUN-type-PASS", packetIn: ipv6STUNBindingReqSTUNTypePass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqSTUNTypePass, @@ -773,7 +775,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request STUN magic PASS", + name: "ipv4-STUN-Binding-Request-STUN-magic-PASS", packetIn: ipv4STUNBindingReqSTUNMagicPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqSTUNMagicPass, @@ -791,7 +793,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv6 STUN Binding Request STUN magic PASS", + name: "ipv6-STUN-Binding-Request-STUN-magic-PASS", packetIn: ipv6STUNBindingReqSTUNMagicPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqSTUNMagicPass, @@ -809,7 +811,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request STUN attrs len PASS", + name: "ipv4-STUN-Binding-Request-STUN-attrs-len-PASS", packetIn: ipv4STUNBindingReqSTUNAttrsLenPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqSTUNAttrsLenPass, @@ -827,7 +829,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv6 STUN Binding Request STUN attrs len PASS", + name: "ipv6-STUN-Binding-Request-STUN-attrs-len-PASS", packetIn: ipv6STUNBindingReqSTUNAttrsLenPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqSTUNAttrsLenPass, @@ -845,7 +847,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request STUN SW val PASS", + name: "ipv4-STUN-Binding-Request-STUN-SW-val-PASS", packetIn: ipv4STUNBindingReqSTUNSWValPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqSTUNSWValPass, @@ -863,7 +865,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv6 STUN Binding Request STUN SW val PASS", + name: "ipv6-STUN-Binding-Request-STUN-SW-val-PASS", packetIn: ipv6STUNBindingReqSTUNSWValPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqSTUNSWValPass, @@ -881,7 +883,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 STUN Binding Request STUN first attr PASS", + name: "ipv4-STUN-Binding-Request-STUN-first-attr-PASS", packetIn: ipv4STUNBindingReqSTUNFirstAttrPass, wantCode: xdpActionPass, wantPacketOut: ipv4STUNBindingReqSTUNFirstAttrPass, @@ -899,7 +901,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv6 STUN Binding Request STUN first attr PASS", + name: "ipv6-STUN-Binding-Request-STUN-first-attr-PASS", packetIn: ipv6STUNBindingReqSTUNFirstAttrPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqSTUNFirstAttrPass, @@ -917,7 +919,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv4 UDP zero csum TX", + name: "ipv4-UDP-zero-csum-TX", packetIn: ipv4STUNBindingReqUDPZeroCsumTx, wantCode: xdpActionTX, wantPacketOut: getIPv4STUNBindingResp(), @@ -935,7 +937,7 @@ func TestXDP(t *testing.T) { }, }, { - name: "ipv6 UDP zero csum PASS", + name: "ipv6-UDP-zero-csum-PASS", packetIn: ipv6STUNBindingReqUDPZeroCsumPass, wantCode: xdpActionPass, wantPacketOut: ipv6STUNBindingReqUDPZeroCsumPass, @@ -957,10 +959,6 @@ func TestXDP(t *testing.T) { server, err := NewSTUNServer(&STUNServerConfig{DeviceName: "fake", DstPort: defaultSTUNPort}, &noAttachOption{}) if err != nil { - if errors.Is(err, unix.EPERM) { - // TODO(jwhited): get this running - t.Skip("skipping due to EPERM error; test requires elevated privileges") - } t.Fatalf("error constructing STUN server: %v", err) } defer server.Close() diff --git a/disco/disco.go b/disco/disco.go index 2147529d175d4..19a172412ce70 100644 --- a/disco/disco.go +++ b/disco/disco.go @@ -307,8 +307,7 @@ func MessageSummary(m Message) string { // BindUDPRelayHandshakeState represents the state of the 3-way bind handshake // between UDP relay client and UDP relay server. Its potential values include // those for both participants, UDP relay client and UDP relay server. A UDP -// relay server implementation can be found in net/udprelay. This is currently -// considered experimental. +// relay server implementation can be found in net/udprelay. type BindUDPRelayHandshakeState int const ( @@ -475,7 +474,7 @@ const allocateUDPRelayEndpointRequestLen = key.DiscoPublicRawLen*2 + // ClientDi func (m *AllocateUDPRelayEndpointRequest) AppendMarshal(b []byte) []byte { ret, p := appendMsgHeader(b, TypeAllocateUDPRelayEndpointRequest, v0, allocateUDPRelayEndpointRequestLen) - for i := 0; i < len(m.ClientDisco); i++ { + for i := range len(m.ClientDisco) { disco := m.ClientDisco[i].AppendTo(nil) copy(p, disco) p = p[key.DiscoPublicRawLen:] @@ -492,7 +491,7 @@ func parseAllocateUDPRelayEndpointRequest(ver uint8, p []byte) (m *AllocateUDPRe if len(p) < allocateUDPRelayEndpointRequestLen { return m, errShort } - for i := 0; i < len(m.ClientDisco); i++ { + for i := range len(m.ClientDisco) { m.ClientDisco[i] = key.DiscoPublicFromRaw32(mem.B(p[:key.DiscoPublicRawLen])) p = p[key.DiscoPublicRawLen:] } @@ -565,7 +564,7 @@ func (m *UDPRelayEndpoint) encode(b []byte) { disco := m.ServerDisco.AppendTo(nil) copy(b, disco) b = b[key.DiscoPublicRawLen:] - for i := 0; i < len(m.ClientDisco); i++ { + for i := range len(m.ClientDisco) { disco = m.ClientDisco[i].AppendTo(nil) copy(b, disco) b = b[key.DiscoPublicRawLen:] @@ -594,7 +593,7 @@ func (m *UDPRelayEndpoint) decode(b []byte) error { } m.ServerDisco = key.DiscoPublicFromRaw32(mem.B(b[:key.DiscoPublicRawLen])) b = b[key.DiscoPublicRawLen:] - for i := 0; i < len(m.ClientDisco); i++ { + for i := range len(m.ClientDisco) { m.ClientDisco[i] = key.DiscoPublicFromRaw32(mem.B(b[:key.DiscoPublicRawLen])) b = b[key.DiscoPublicRawLen:] } diff --git a/docs/commit-messages.md b/docs/commit-messages.md index b617e1fadd425..57047a803414e 100644 --- a/docs/commit-messages.md +++ b/docs/commit-messages.md @@ -43,13 +43,14 @@ Notably, for the subject (the first line of description): Examples: - | Good Example | notes | - | ------- | --- | - | `foo/bar: fix memory leak` | | - | `foo/bar: bump deps` | | - | `foo/bar: temporarily restrict access` | adverbs are okay | - | `foo/bar: implement new UI design` | | - | `control/{foo,bar}: optimize bar` | feel free to use {foo,bar} for common subpackages| + | Good Example | notes | + |----------------------------------------|---------------------------------------------------| + | `foo/bar: fix memory leak` | | + | `foo/bar: bump deps` | | + | `foo/bar: temporarily restrict access` | adverbs are okay | + | `foo/bar: implement new UI design` | | + | `control/{foo,bar}: optimize bar` | feel free to use {foo,bar} for common subpackages | + | `control,docs: document control usage` | multiple top-level packages are affected | | Bad Example | notes | | ------- | --- | @@ -73,7 +74,7 @@ For the body (the rest of the description): - blank line after the subject (first) line - the text should be wrapped to ~76 characters (to appease git viewing tools, mainly), unless you really need longer lines (e.g. for ASCII art, tables, or long links) - there must be a `Fixes` or `Updates` line for all non-cleanup commits linking to a tracking bug. This goes after the body with a blank newline separating the two. A pull request may be referenced rather than a tracking bug (using the same format, e.g. `Updates #12345`), though a bug is generally preferred. [Cleanup commits](#is-it-a-cleanup) can use `Updates #cleanup` instead of an issue. -- `Change-Id` lines should ideally be included in commits in the `corp` repo and are more optional in `tailscale/tailscale`. You can configure Git to do this for you by running `./tool/go run misc/install-git-hooks.go` from the root of the corp repo. This was originally a Gerrit thing and we don't use Gerrit, but it lets us tooling track commits as they're cherry-picked between branches. Also, tools like [git-cleanup](https://github.com/bradfitz/gitutil) use it to clean up your old local branches once they're merged upstream. +- `Change-Id` lines should be included in commits. You can configure Git to do this for you by running `./tool/go run misc/install-git-hooks.go` from the root of the repo. This was originally a Gerrit thing and we don't use Gerrit, but it lets tooling track commits as they're cherry-picked between branches. Also, tools like [git-cleanup](https://github.com/bradfitz/gitutil) use it to clean up your old local branches once they're merged upstream. - we don't use Markdown in commit messages. (Accidental Markdown like bulleted lists or even headings is fine, but not links) - we require `Signed-off-by` lines in public repos (such as `tailscale/tailscale`). Add them using `git commit --signoff` or `git commit -s` for short. You can use them in private repos but do not have to. - when moving code between repos, include the repository name, and git hash that it was moved from/to, so it is easier to trace history/blame. diff --git a/docs/embed.go b/docs/embed.go new file mode 100644 index 0000000000000..420a4d42362a3 --- /dev/null +++ b/docs/embed.go @@ -0,0 +1,12 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package docs embeds certain docs, making them available for other packages. +package docs + +import _ "embed" + +// CommitMessages is the contents of commit-messages.md. +// +//go:embed commit-messages.md +var CommitMessages string diff --git a/docs/webhooks/example.go b/docs/webhooks/example.go index 53ec1c8b74b52..d93c425d2c3d2 100644 --- a/docs/webhooks/example.go +++ b/docs/webhooks/example.go @@ -87,7 +87,7 @@ func verifyWebhookSignature(req *http.Request, secret string) (events []event, e return nil, err } mac := hmac.New(sha256.New, []byte(secret)) - mac.Write([]byte(fmt.Sprint(timestamp.Unix()))) + mac.Write(fmt.Append(nil, timestamp.Unix())) mac.Write([]byte(".")) mac.Write(b) want := hex.EncodeToString(mac.Sum(nil)) @@ -120,8 +120,8 @@ func parseSignatureHeader(header string) (timestamp time.Time, signatures map[st } signatures = make(map[string][]string) - pairs := strings.Split(header, ",") - for _, pair := range pairs { + pairs := strings.SplitSeq(header, ",") + for pair := range pairs { parts := strings.Split(pair, "=") if len(parts) != 2 { return time.Time{}, nil, errNotSigned diff --git a/docs/windows/policy/en-US/tailscale.adml b/docs/windows/policy/en-US/tailscale.adml index a0be5e8314a2b..225c1a8d40dc5 100644 --- a/docs/windows/policy/en-US/tailscale.adml +++ b/docs/windows/policy/en-US/tailscale.adml @@ -178,6 +178,16 @@ If you disable this policy, then Automatically Install Updates is always disable If you do not configure this policy, then Automatically Install Updates depends on what is selected in the Preferences submenu. See https://tailscale.com/kb/1067/update#auto-updates for more details.]]> + Automatically check for updates + Run Tailscale as an Exit Node never + + + + + always + + + never + + diff --git a/drive/driveimpl/compositedav/rewriting.go b/drive/driveimpl/compositedav/rewriting.go index 47f020461b77d..1f0a69d75978e 100644 --- a/drive/driveimpl/compositedav/rewriting.go +++ b/drive/driveimpl/compositedav/rewriting.go @@ -63,7 +63,7 @@ func (h *Handler) delegateRewriting(w http.ResponseWriter, r *http.Request, path // Fixup paths to add the requested path as a prefix, escaped for inclusion in XML. pp := shared.EscapeForXML(shared.Join(pathComponents[0:mpl]...)) - b := responseHrefRegex.ReplaceAll(bw.buf.Bytes(), []byte(fmt.Sprintf("$1%s/$3", pp))) + b := responseHrefRegex.ReplaceAll(bw.buf.Bytes(), fmt.Appendf(nil, "$1%s/$3", pp)) return bw.status, b } diff --git a/drive/driveimpl/dirfs/dirfs_test.go b/drive/driveimpl/dirfs/dirfs_test.go index c5f3aed3a99f0..559160716bdba 100644 --- a/drive/driveimpl/dirfs/dirfs_test.go +++ b/drive/driveimpl/dirfs/dirfs_test.go @@ -29,7 +29,7 @@ func TestStat(t *testing.T) { err error }{ { - label: "root folder", + label: "root-folder", name: "", expected: &shared.StaticFileInfo{ Named: "", @@ -40,7 +40,7 @@ func TestStat(t *testing.T) { }, }, { - label: "static root folder", + label: "static-root-folder", name: "/domain", expected: &shared.StaticFileInfo{ Named: "domain", @@ -73,7 +73,7 @@ func TestStat(t *testing.T) { }, }, { - label: "non-existent remote", + label: "non-existent-remote", name: "remote3", err: os.ErrNotExist, }, @@ -108,7 +108,7 @@ func TestListDir(t *testing.T) { err error }{ { - label: "root folder", + label: "root-folder", name: "", expected: []fs.FileInfo{ &shared.StaticFileInfo{ @@ -121,7 +121,7 @@ func TestListDir(t *testing.T) { }, }, { - label: "static root folder", + label: "static-root-folder", name: "/domain", expected: []fs.FileInfo{ &shared.StaticFileInfo{ @@ -189,19 +189,19 @@ func TestMkdir(t *testing.T) { err error }{ { - label: "attempt to create root folder", + label: "create-root-folder", name: "/", }, { - label: "attempt to create static root folder", + label: "create-static-root-folder", name: "/domain", }, { - label: "attempt to create remote", + label: "create-remote", name: "/domain/remote1", }, { - label: "attempt to create non-existent remote", + label: "create-non-existent-remote", name: "/domain/remote3", err: os.ErrPermission, }, @@ -231,7 +231,7 @@ func TestRemoveAll(t *testing.T) { err error }{ { - label: "attempt to remove root folder", + label: "remove-root-folder", name: "/", err: os.ErrPermission, }, @@ -258,7 +258,7 @@ func TestRename(t *testing.T) { err error }{ { - label: "attempt to move root folder", + label: "move-root-folder", oldName: "/", newName: "/domain/remote2/copy.txt", err: os.ErrPermission, diff --git a/drive/driveimpl/drive_test.go b/drive/driveimpl/drive_test.go index db7bfe60bde19..185ae2a9c2118 100644 --- a/drive/driveimpl/drive_test.go +++ b/drive/driveimpl/drive_test.go @@ -156,27 +156,27 @@ func TestMissingPaths(t *testing.T) { wantStatus int }{ { - name: "empty path", + name: "empty-path", path: "", wantStatus: http.StatusForbidden, }, { - name: "single slash", + name: "single-slash", path: "/", wantStatus: http.StatusForbidden, }, { - name: "only token", + name: "only-token", path: "/" + secretToken, wantStatus: http.StatusBadRequest, }, { - name: "token with trailing slash", + name: "token-trailing-slash", path: "/" + secretToken + "/", wantStatus: http.StatusBadRequest, }, { - name: "token and invalid share", + name: "token-invalid-share", path: "/" + secretToken + "/nonexistentshare", wantStatus: http.StatusNotFound, }, @@ -239,7 +239,7 @@ func TestLOCK(t *testing.T) { } u := fmt.Sprintf("http://%s/%s/%s/%s/%s", - s.local.l.Addr(), + s.local.ln.Addr(), url.PathEscape(domain), url.PathEscape(remote1), url.PathEscape(share11), @@ -365,7 +365,7 @@ func TestUNLOCK(t *testing.T) { } u := fmt.Sprintf("http://%s/%s/%s/%s/%s", - s.local.l.Addr(), + s.local.ln.Addr(), url.PathEscape(domain), url.PathEscape(remote1), url.PathEscape(share11), @@ -428,12 +428,12 @@ func TestUNLOCK(t *testing.T) { } type local struct { - l net.Listener + ln net.Listener fs *FileSystemForLocal } type remote struct { - l net.Listener + ln net.Listener fs *FileSystemForRemote fileServer *FileServer shares map[string]string @@ -487,7 +487,7 @@ func newSystem(t *testing.T) *system { client.SetTransport(&http.Transport{DisableKeepAlives: true}) s := &system{ t: t, - local: &local{l: ln, fs: fs}, + local: &local{ln: ln, fs: fs}, client: client, remotes: make(map[string]*remote), } @@ -510,7 +510,7 @@ func (s *system) addRemote(name string) string { s.t.Logf("FileServer for %v listening at %s", name, fileServer.Addr()) r := &remote{ - l: ln, + ln: ln, fileServer: fileServer, fs: NewFileSystemForRemote(log.Printf), shares: make(map[string]string), @@ -524,7 +524,7 @@ func (s *system) addRemote(name string) string { for name, r := range s.remotes { remotes = append(remotes, &drive.Remote{ Name: name, - URL: func() string { return fmt.Sprintf("http://%s", r.l.Addr()) }, + URL: func() string { return fmt.Sprintf("http://%s", r.ln.Addr()) }, }) } s.local.fs.SetRemotes( @@ -683,7 +683,7 @@ func (s *system) stop() { s.t.Fatalf("failed to Close fs: %s", err) } - err = s.local.l.Close() + err = s.local.ln.Close() if err != nil { s.t.Fatalf("failed to Close listener: %s", err) } @@ -694,7 +694,7 @@ func (s *system) stop() { s.t.Fatalf("failed to Close remote fs: %s", err) } - err = r.l.Close() + err = r.ln.Close() if err != nil { s.t.Fatalf("failed to Close remote listener: %s", err) } diff --git a/drive/driveimpl/remote_impl.go b/drive/driveimpl/remote_impl.go index df27ba71627df..0ff27dc643efe 100644 --- a/drive/driveimpl/remote_impl.go +++ b/drive/driveimpl/remote_impl.go @@ -415,7 +415,7 @@ var writeMethods = map[string]bool{ "DELETE": true, } -// canSudo checks wether we can sudo -u the configured executable as the +// canSudo checks whether we can sudo -u the configured executable as the // configured user by attempting to call the executable with the '-h' flag to // print help. func (s *userServer) canSudo() bool { diff --git a/envknob/envknob.go b/envknob/envknob.go index 2b1461f11f308..73a0da7005041 100644 --- a/envknob/envknob.go +++ b/envknob/envknob.go @@ -405,6 +405,9 @@ func SSHIgnoreTailnetPolicy() bool { return Bool("TS_DEBUG_SSH_IGNORE_TAILNET_PO // TKASkipSignatureCheck reports whether to skip node-key signature checking for development. func TKASkipSignatureCheck() bool { return Bool("TS_UNSAFE_SKIP_NKS_VERIFICATION") } +// AssumeNetworkUp reports whether to assume network connectivity for development. +func AssumeNetworkUp() bool { return Bool("TS_ASSUME_NETWORK_UP_FOR_TEST") } + // App returns the tailscale app type of this instance, if set via // TS_INTERNAL_APP env var. TS_INTERNAL_APP can be used to set app type for // components that wrap tailscaled, such as containerboot. App type is intended diff --git a/envknob/logknob/logknob.go b/envknob/logknob/logknob.go deleted file mode 100644 index bc6e8c3627077..0000000000000 --- a/envknob/logknob/logknob.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package logknob provides a helpful wrapper that allows enabling logging -// based on either an envknob or other methods of enablement. -package logknob - -import ( - "sync/atomic" - - "tailscale.com/envknob" - "tailscale.com/tailcfg" - "tailscale.com/types/logger" -) - -// TODO(andrew-d): should we have a package-global registry of logknobs? It -// would allow us to update from a netmap in a central location, which might be -// reason enough to do it... - -// LogKnob allows configuring verbose logging, with multiple ways to enable. It -// supports enabling logging via envknob, via atomic boolean (for use in e.g. -// c2n log level changes), and via capabilities from a NetMap (so users can -// enable logging via the ACL JSON). -type LogKnob struct { - capName tailcfg.NodeCapability - cap atomic.Bool - env func() bool - manual atomic.Bool -} - -// NewLogKnob creates a new LogKnob, with the provided environment variable -// name and/or NetMap capability. -func NewLogKnob(env string, cap tailcfg.NodeCapability) *LogKnob { - if env == "" && cap == "" { - panic("must provide either an environment variable or capability") - } - - lk := &LogKnob{ - capName: cap, - } - if env != "" { - lk.env = envknob.RegisterBool(env) - } else { - lk.env = func() bool { return false } - } - return lk -} - -// Set will cause logs to be printed when called with Set(true). When called -// with Set(false), logs will not be printed due to an earlier call of -// Set(true), but may be printed due to either the envknob and/or capability of -// this LogKnob. -func (lk *LogKnob) Set(v bool) { - lk.manual.Store(v) -} - -// NetMap is an interface for the parts of netmap.NetworkMap that we care -// about; we use this rather than a concrete type to avoid a circular -// dependency. -type NetMap interface { - HasSelfCapability(tailcfg.NodeCapability) bool -} - -// UpdateFromNetMap will enable logging if the SelfNode in the provided NetMap -// contains the capability provided for this LogKnob. -func (lk *LogKnob) UpdateFromNetMap(nm NetMap) { - if lk.capName == "" { - return - } - lk.cap.Store(nm.HasSelfCapability(lk.capName)) -} - -// Do will call log with the provided format and arguments if any of the -// configured methods for enabling logging are true. -func (lk *LogKnob) Do(log logger.Logf, format string, args ...any) { - if lk.shouldLog() { - log(format, args...) - } -} - -func (lk *LogKnob) shouldLog() bool { - return lk.manual.Load() || lk.env() || lk.cap.Load() -} diff --git a/envknob/logknob/logknob_test.go b/envknob/logknob/logknob_test.go deleted file mode 100644 index 9e7ab8aef6368..0000000000000 --- a/envknob/logknob/logknob_test.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package logknob - -import ( - "bytes" - "fmt" - "testing" - - "tailscale.com/envknob" - "tailscale.com/tailcfg" - "tailscale.com/types/netmap" - "tailscale.com/util/set" -) - -var testKnob = NewLogKnob( - "TS_TEST_LOGKNOB", - "https://tailscale.com/cap/testing", -) - -// Static type assertion for our interface type. -var _ NetMap = &netmap.NetworkMap{} - -func TestLogKnob(t *testing.T) { - t.Run("Default", func(t *testing.T) { - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - assertNoLogs(t) - }) - t.Run("Manual", func(t *testing.T) { - t.Cleanup(func() { testKnob.Set(false) }) - - assertNoLogs(t) - testKnob.Set(true) - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) - t.Run("Env", func(t *testing.T) { - t.Cleanup(func() { - envknob.Setenv("TS_TEST_LOGKNOB", "") - }) - - assertNoLogs(t) - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - - envknob.Setenv("TS_TEST_LOGKNOB", "true") - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) - t.Run("NetMap", func(t *testing.T) { - t.Cleanup(func() { testKnob.cap.Store(false) }) - - assertNoLogs(t) - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - - testKnob.UpdateFromNetMap(&netmap.NetworkMap{ - AllCaps: set.Of(tailcfg.NodeCapability("https://tailscale.com/cap/testing")), - }) - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) -} - -func assertLogs(t *testing.T) { - var buf bytes.Buffer - logf := func(format string, args ...any) { - fmt.Fprintf(&buf, format, args...) - } - - testKnob.Do(logf, "hello %s", "world") - const want = "hello world" - if got := buf.String(); got != want { - t.Errorf("got %q, want %q", got, want) - } -} - -func assertNoLogs(t *testing.T) { - var buf bytes.Buffer - logf := func(format string, args ...any) { - fmt.Fprintf(&buf, format, args...) - } - - testKnob.Do(logf, "hello %s", "world") - if got := buf.String(); got != "" { - t.Errorf("expected no logs, but got: %q", got) - } -} diff --git a/feature/buildfeatures/feature_ipnbus_disabled.go b/feature/buildfeatures/feature_ipnbus_disabled.go new file mode 100644 index 0000000000000..b71dbda62e34e --- /dev/null +++ b/feature/buildfeatures/feature_ipnbus_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_ipnbus + +package buildfeatures + +// HasIPNBus is whether the binary was built with support for modular feature "IPN notification bus (watch-ipn-bus) support, used by GUIs, debugging, and nicer 'tailscale up' support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_ipnbus" build tag. +// It's a const so it can be used for dead code elimination. +const HasIPNBus = false diff --git a/feature/buildfeatures/feature_ipnbus_enabled.go b/feature/buildfeatures/feature_ipnbus_enabled.go new file mode 100644 index 0000000000000..74d9547293f66 --- /dev/null +++ b/feature/buildfeatures/feature_ipnbus_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_ipnbus + +package buildfeatures + +// HasIPNBus is whether the binary was built with support for modular feature "IPN notification bus (watch-ipn-bus) support, used by GUIs, debugging, and nicer 'tailscale up' support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_ipnbus" build tag. +// It's a const so it can be used for dead code elimination. +const HasIPNBus = true diff --git a/feature/buildfeatures/feature_lazywg_disabled.go b/feature/buildfeatures/feature_tundevstats_disabled.go similarity index 54% rename from feature/buildfeatures/feature_lazywg_disabled.go rename to feature/buildfeatures/feature_tundevstats_disabled.go index af1ad388c03a7..e78816138a419 100644 --- a/feature/buildfeatures/feature_lazywg_disabled.go +++ b/feature/buildfeatures/feature_tundevstats_disabled.go @@ -3,11 +3,11 @@ // Code generated by gen.go; DO NOT EDIT. -//go:build ts_omit_lazywg +//go:build ts_omit_tundevstats package buildfeatures -// HasLazyWG is whether the binary was built with support for modular feature "Lazy WireGuard configuration for memory-constrained devices with large netmaps". -// Specifically, it's whether the binary was NOT built with the "ts_omit_lazywg" build tag. +// HasTUNDevStats is whether the binary was built with support for modular feature "Poll TUN device statistics (Linux only)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_tundevstats" build tag. // It's a const so it can be used for dead code elimination. -const HasLazyWG = false +const HasTUNDevStats = false diff --git a/feature/buildfeatures/feature_lazywg_enabled.go b/feature/buildfeatures/feature_tundevstats_enabled.go similarity index 54% rename from feature/buildfeatures/feature_lazywg_enabled.go rename to feature/buildfeatures/feature_tundevstats_enabled.go index f2d6a10f81580..ffa4ebee7d31a 100644 --- a/feature/buildfeatures/feature_lazywg_enabled.go +++ b/feature/buildfeatures/feature_tundevstats_enabled.go @@ -3,11 +3,11 @@ // Code generated by gen.go; DO NOT EDIT. -//go:build !ts_omit_lazywg +//go:build !ts_omit_tundevstats package buildfeatures -// HasLazyWG is whether the binary was built with support for modular feature "Lazy WireGuard configuration for memory-constrained devices with large netmaps". -// Specifically, it's whether the binary was NOT built with the "ts_omit_lazywg" build tag. +// HasTUNDevStats is whether the binary was built with support for modular feature "Poll TUN device statistics (Linux only)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_tundevstats" build tag. // It's a const so it can be used for dead code elimination. -const HasLazyWG = true +const HasTUNDevStats = true diff --git a/feature/clientupdate/clientupdate.go b/feature/clientupdate/clientupdate.go index d47d048156046..999dd79200de0 100644 --- a/feature/clientupdate/clientupdate.go +++ b/feature/clientupdate/clientupdate.go @@ -163,6 +163,7 @@ func (e *extension) DoSelfUpdate() { }) if err != nil { e.pushSelfUpdateProgress(ipnstate.NewUpdateProgress(ipnstate.UpdateFailed, err.Error())) + return } err = up.Update() if err != nil { diff --git a/feature/condlite/expvar/omit.go b/feature/condlite/expvar/omit.go index b5481695c9947..188de2af2436f 100644 --- a/feature/condlite/expvar/omit.go +++ b/feature/condlite/expvar/omit.go @@ -3,7 +3,6 @@ //go:build ts_omit_debug && ts_omit_clientmetrics && ts_omit_usermetrics -// excluding the package from builds. package expvar type Int int64 diff --git a/feature/condregister/maybe_desktop_sessions.go b/feature/condregister/maybe_desktop_sessions.go new file mode 100644 index 0000000000000..bb93a8bcbff7c --- /dev/null +++ b/feature/condregister/maybe_desktop_sessions.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build windows && !ts_omit_desktop_sessions + +package condregister + +import _ "tailscale.com/ipn/desktop" diff --git a/feature/condregister/maybe_tailnetlock.go b/feature/condregister/maybe_tailnetlock.go new file mode 100644 index 0000000000000..80a3dffe31aee --- /dev/null +++ b/feature/condregister/maybe_tailnetlock.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_tailnetlock + +package condregister + +import _ "tailscale.com/feature/tailnetlock" diff --git a/feature/condregister/maybe_tundevstats.go b/feature/condregister/maybe_tundevstats.go new file mode 100644 index 0000000000000..f678a8c9c2f77 --- /dev/null +++ b/feature/condregister/maybe_tundevstats.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !ts_omit_tundevstats + +package condregister + +import _ "tailscale.com/feature/tundevstats" diff --git a/feature/conn25/conn25.go b/feature/conn25/conn25.go index 05f087e21df46..9bdda1cebf904 100644 --- a/feature/conn25/conn25.go +++ b/feature/conn25/conn25.go @@ -8,44 +8,87 @@ package conn25 import ( + "bytes" + "context" "encoding/json" "errors" + "fmt" + "io" "net/http" "net/netip" + "slices" + "strings" "sync" + "sync/atomic" + "time" "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" + "tailscale.com/appc" + "tailscale.com/envknob" "tailscale.com/feature" + "tailscale.com/ipn" "tailscale.com/ipn/ipnext" "tailscale.com/ipn/ipnlocal" - "tailscale.com/net/dns" + "tailscale.com/net/packet" + "tailscale.com/net/tsaddr" + "tailscale.com/net/tstun" "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/types/appctype" + "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/views" "tailscale.com/util/dnsname" "tailscale.com/util/mak" "tailscale.com/util/set" + "tailscale.com/util/testenv" + "tailscale.com/wgengine/filter" ) // featureName is the name of the feature implemented by this package. // It is also the [extension] name and the log prefix. const featureName = "conn25" +const maxBodyBytes = 1024 * 1024 + +// jsonDecode decodes all of a io.ReadCloser (eg an http.Request Body) into one pointer with best practices. +// It limits the size of bytes it will read. +// It either decodes all of the bytes into the pointer, or errors (unlike json.Decoder.Decode). +// It closes the ReadCloser after reading. +func jsonDecode(target any, rc io.ReadCloser) error { + defer rc.Close() + respBs, err := io.ReadAll(io.LimitReader(rc, maxBodyBytes+1)) + if err != nil { + return err + } + err = json.Unmarshal(respBs, &target) + return err +} + +func normalizeDNSName(name string) (dnsname.FQDN, error) { + // note that appconnector does this same thing, tsdns has its own custom lower casing + // it might be good to unify in a function in dnsname package. + return dnsname.ToFQDN(strings.ToLower(name)) +} + func init() { feature.Register(featureName) - newExtension := func(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) { - e := &extension{ + ipnext.RegisterExtension(featureName, func(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) { + return &extension{ conn25: newConn25(logger.WithPrefix(logf, "conn25: ")), backend: sb, - } - return e, nil - } - ipnext.RegisterExtension(featureName, newExtension) + }, nil + }) ipnlocal.RegisterPeerAPIHandler("/v0/connector/transit-ip", handleConnectorTransitIP) } func handleConnectorTransitIP(h ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) { + // TODO(tailscale/corp#39033): Remove for alpha release. + if !envknob.UseWIPCode() && !testenv.InTest() { + w.WriteHeader(http.StatusNotImplemented) + return + } e, ok := ipnlocal.GetExt[*extension](h.LocalBackend()) if !ok { http.Error(w, "miswired", http.StatusInternalServerError) @@ -60,8 +103,8 @@ type extension struct { conn25 *Conn25 // safe for concurrent access and only set at creation backend ipnext.SafeBackend // safe for concurrent access and only set at creation - mu sync.Mutex // protects the fields below - isDNSHookRegistered bool + host ipnext.Host // set in Init, read-only after + ctxCancel context.CancelCauseFunc // cancels sendLoop goroutine } // Name implements [ipnext.Extension]. @@ -71,17 +114,168 @@ func (e *extension) Name() string { // Init implements [ipnext.Extension]. func (e *extension) Init(host ipnext.Host) error { - host.Hooks().OnSelfChange.Add(e.onSelfChange) + // TODO(tailscale/corp#39033): Remove for alpha release. + if !envknob.UseWIPCode() && !testenv.InTest() { + return ipnext.SkipExtension + } + + if e.ctxCancel != nil { + return nil + } + e.host = host + + dph := newDatapathHandler(e.conn25, e.conn25.logf) + if err := e.installHooks(dph); err != nil { + return err + } + profile, prefs := e.host.Profiles().CurrentProfileState() + e.profileStateChange(profile, prefs, false) + + ctx, cancel := context.WithCancelCause(context.Background()) + e.ctxCancel = cancel + go e.sendLoop(ctx) return nil } +func (e *extension) installHooks(dph *datapathHandler) error { + // Make sure we can access the DNS manager and the system tun. + dnsManager, ok := e.backend.Sys().DNSManager.GetOK() + if !ok { + return errors.New("could not access system dns manager") + } + tun, ok := e.backend.Sys().Tun.GetOK() + if !ok { + return errors.New("could not access system tun") + } + + // Set up the DNS manager to rewrite responses for app domains + // to answer with Magic IPs. + dnsManager.SetQueryResponseMapper(func(bs []byte) []byte { + if !e.conn25.isConfigured() { + return bs + } + return e.conn25.mapDNSResponse(bs) + }) + + // Intercept packets from the tun device and from WireGuard + // to perform DNAT and SNAT. + tun.PreFilterPacketOutboundToWireGuardAppConnectorIntercept = func(p *packet.Parsed, _ *tstun.Wrapper) filter.Response { + if !e.conn25.isConfigured() { + return filter.Accept + } + return dph.HandlePacketFromTunDevice(p) + } + tun.PostFilterPacketInboundFromWireGuardAppConnector = func(p *packet.Parsed, _ *tstun.Wrapper) filter.Response { + if !e.conn25.isConfigured() { + return filter.Accept + } + return dph.HandlePacketFromWireGuard(p) + } + + // Manage how we react to changes to the current node, + // including property changes (e.g. HostInfo, Capabilities, CapMap). + e.host.Hooks().OnSelfChange.Add(e.onSelfChange) + + // Manage how we react profile state changes, which include + // prefs changes. + e.host.Hooks().ProfileStateChange.Add(e.profileStateChange) + + // Allow the client to send packets with Transit IP destinations + // in the link-local space. + e.host.Hooks().Filter.LinkLocalAllowHooks.Add(func(p packet.Parsed) (bool, string) { + if !e.conn25.isConfigured() { + return false, "" + } + return e.conn25.client.linkLocalAllow(p) + }) + + // Allow the connector to receive packets with Transit IP destinations + // in the link-local space. + e.host.Hooks().Filter.LinkLocalAllowHooks.Add(func(p packet.Parsed) (bool, string) { + if !e.conn25.isConfigured() { + return false, "" + } + return e.conn25.connector.packetFilterAllow(p) + }) + + // Allow the connector to receive packets with Transit IP destinations + // that are not "local" to it, and that it does not advertise. + e.host.Hooks().Filter.IngressAllowHooks.Add(func(p packet.Parsed) (bool, string) { + if !e.conn25.isConfigured() { + return false, "" + } + return e.conn25.connector.packetFilterAllow(p) + }) + + // Give the client the Magic IP range to install on the OS. + e.host.Hooks().ExtraRouterConfigRoutes.Set(func() views.Slice[netip.Prefix] { + if !e.conn25.isConfigured() { + return views.Slice[netip.Prefix]{} + } + return e.getMagicRange() + }) + + // Tell WireGuard what Transit IPs belong to which connector peers. + e.host.Hooks().ExtraWireGuardAllowedIPs.Set(func(k key.NodePublic) views.Slice[netip.Prefix] { + if !e.conn25.isConfigured() { + return views.Slice[netip.Prefix]{} + } + return e.extraWireGuardAllowedIPs(k) + }) + + return nil +} + +// ClientTransitIPForMagicIP implements [IPMapper]. +func (c *Conn25) ClientTransitIPForMagicIP(m netip.Addr) (netip.Addr, error) { + if addr, ok := c.client.transitIPForMagicIP(m); ok { + return addr, nil + } + cfg, ok := c.getConfig() + if !ok { + return netip.Addr{}, nil + } + if !cfg.ipSets.v4Magic.Contains(m) && !cfg.ipSets.v6Magic.Contains(m) { + return netip.Addr{}, nil + } + return netip.Addr{}, ErrUnmappedMagicIP +} + +// ConnectorRealIPForTransitIPConnection implements [IPMapper]. +func (c *Conn25) ConnectorRealIPForTransitIPConnection(src, transit netip.Addr) (netip.Addr, error) { + if addr, ok := c.connector.realIPForTransitIPConnection(src, transit); ok { + return addr, nil + } + cfg, ok := c.getConfig() + if !ok { + return netip.Addr{}, nil + } + if !cfg.ipSets.v4Transit.Contains(transit) && !cfg.ipSets.v6Transit.Contains(transit) { + return netip.Addr{}, nil + } + return netip.Addr{}, ErrUnmappedSrcAndTransitIP +} + +func (e *extension) getMagicRange() views.Slice[netip.Prefix] { + cfg, ok := e.conn25.getConfig() + if !ok { + return views.Slice[netip.Prefix]{} + } + return views.SliceOf(slices.Concat(cfg.ipSets.v4Magic.Prefixes(), cfg.ipSets.v6Magic.Prefixes())) +} + // Shutdown implements [ipnlocal.Extension]. func (e *extension) Shutdown() error { + if e.ctxCancel != nil { + e.ctxCancel(errors.New("extension shutdown")) + } + if e.conn25 != nil { + close(e.conn25.client.addrsCh) + } return nil } func (e *extension) handleConnectorTransitIP(h ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) { - const maxBodyBytes = 1024 * 1024 defer r.Body.Close() if r.Method != "POST" { http.Error(w, "Method should be POST", http.StatusMethodNotAllowed) @@ -93,7 +287,7 @@ func (e *extension) handleConnectorTransitIP(h ipnlocal.PeerAPIHandler, w http.R http.Error(w, "Error decoding JSON", http.StatusBadRequest) return } - resp := e.conn25.handleConnectorTransitIPRequest(h.Peer().ID(), req) + resp := e.conn25.handleConnectorTransitIPRequest(h.Peer(), req) bs, err := json.Marshal(resp) if err != nil { http.Error(w, "Error encoding JSON", http.StatusInternalServerError) @@ -102,56 +296,25 @@ func (e *extension) handleConnectorTransitIP(h ipnlocal.PeerAPIHandler, w http.R w.Write(bs) } +// onSelfChange implements the [ipnext.Hooks.OnSelfChange] hook. func (e *extension) onSelfChange(selfNode tailcfg.NodeView) { - err := e.conn25.reconfig(selfNode) + cfg, err := configFromNodeView(selfNode) if err != nil { - e.conn25.client.logf("error during Reconfig onSelfChange: %v", err) + e.conn25.logf("error generating config from self node view: %v", err) return } - - if e.conn25.isConfigured() { - err = e.registerDNSHook() - } else { - err = e.unregisterDNSHook() - } - if err != nil { - e.conn25.client.logf("error managing DNS hook onSelfChange: %v", err) - } + e.conn25.reconfig(cfg) } -func (e *extension) registerDNSHook() error { - e.mu.Lock() - defer e.mu.Unlock() - if e.isDNSHookRegistered { - return nil - } - err := e.setDNSHookLocked(e.conn25.mapDNSResponse) - if err == nil { - e.isDNSHookRegistered = true - } - return err +// profileStateChange implements the [ipnext.Hooks.ProfileStateChange] hook. +func (e *extension) profileStateChange(loginProfile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + // TODO(mzb): Handle node changes. Wipe out all config? + // We'll need to look at the ordering of this hook and onSelfChange. + e.conn25.prefsAdvertiseConnector.Store(prefs.AppConnector().Advertise) } -func (e *extension) unregisterDNSHook() error { - e.mu.Lock() - defer e.mu.Unlock() - if !e.isDNSHookRegistered { - return nil - } - err := e.setDNSHookLocked(nil) - if err == nil { - e.isDNSHookRegistered = false - } - return err -} - -func (e *extension) setDNSHookLocked(fx dns.ResponseMapper) error { - dnsManager, ok := e.backend.Sys().DNSManager.GetOK() - if !ok || dnsManager == nil { - return errors.New("couldn't get DNSManager from sys") - } - dnsManager.SetQueryResponseMapper(fx) - return nil +func (e *extension) extraWireGuardAllowedIPs(k key.NodePublic) views.Slice[netip.Prefix] { + return e.conn25.client.extraWireGuardAllowedIPs(k) } type appAddr struct { @@ -161,19 +324,41 @@ type appAddr struct { // Conn25 holds state for routing traffic for a domain via a connector. type Conn25 struct { - client *client - connector *connector + config atomic.Pointer[config] + prefsAdvertiseConnector atomic.Bool + logf logger.Logf + client *client + connector *connector +} + +func (c *Conn25) getConfig() (*config, bool) { + cfg := c.config.Load() + return cfg, cfg.isConfigured } func (c *Conn25) isConfigured() bool { - return c.client.isConfigured() + _, ok := c.getConfig() + return ok } func newConn25(logf logger.Logf) *Conn25 { c := &Conn25{ - client: &client{logf: logf}, + logf: logf, connector: &connector{logf: logf}, } + c.config.Store(&config{}) // initialize with empty to avoid nil checks + c.client = &client{ + logf: logf, + addrsCh: make(chan addrs, 64), + assignments: addrAssignments{clock: tstime.StdClock{}}, + getIPSets: func() ipSets { + cfg, ok := c.getConfig() + if !ok { + return emptyIPSets() + } + return cfg.ipSets + }, + } return c } @@ -185,71 +370,114 @@ func ipSetFromIPRanges(rs []netipx.IPRange) (*netipx.IPSet, error) { return b.IPSet() } -func (c *Conn25) reconfig(selfNode tailcfg.NodeView) error { - cfg, err := configFromNodeView(selfNode) - if err != nil { - return err - } - if err := c.client.reconfig(cfg); err != nil { - return err - } - if err := c.connector.reconfig(cfg); err != nil { - return err - } - return nil -} - -// mapDNSResponse parses and inspects the DNS response, and uses the -// contents to assign addresses for connecting. It does not yet modify -// the response. -func (c *Conn25) mapDNSResponse(buf []byte) []byte { - return c.client.mapDNSResponse(buf) +func (c *Conn25) reconfig(cfg *config) { + c.config.Store(cfg) + c.client.reconfig() } const dupeTransitIPMessage = "Duplicate transit address in ConnectorTransitIPRequest" +const noMatchingPeerIPFamilyMessage = "No peer IP found with matching IP family" +const addrFamilyMismatchMessage = "Transit and Destination addresses must have matching IP family" +const unknownAppNameMessage = "The App name in the request does not match a configured App" -// handleConnectorTransitIPRequest creates a ConnectorTransitIPResponse in response to a ConnectorTransitIPRequest. -// It updates the connectors mapping of TransitIP->DestinationIP per peer (tailcfg.NodeID). -// If a peer has stored this mapping in the connector Conn25 will route traffic to TransitIPs to DestinationIPs for that peer. -func (c *Conn25) handleConnectorTransitIPRequest(nid tailcfg.NodeID, ctipr ConnectorTransitIPRequest) ConnectorTransitIPResponse { +// handleConnectorTransitIPRequest creates a ConnectorTransitIPResponse in response +// to a ConnectorTransitIPRequest. It updates the connectors mapping of +// TransitIP->DestinationIP per peer (using the Peer's IP that matches the address +// family of the transitIP). If a peer has stored this mapping in the connector, +// Conn25 will route traffic to TransitIPs to DestinationIPs for that peer. +func (c *Conn25) handleConnectorTransitIPRequest(n tailcfg.NodeView, ctipr ConnectorTransitIPRequest) ConnectorTransitIPResponse { resp := ConnectorTransitIPResponse{} + cfg, ok := c.getConfig() + if !ok { + // TODO(mzb): If this node is no longer configured at the + // the time of this call, perhaps there should be a top-level + // error, instead of error-per-TransitIP? + for range ctipr.TransitIPs { + resp.TransitIPs = append(resp.TransitIPs, TransitIPResponse{ + Code: UnknownAppName, + Message: unknownAppNameMessage, + }) + } + return resp + } + + var peerIPv4, peerIPv6 netip.Addr + for _, ip := range n.Addresses().All() { + if !ip.IsSingleIP() || !tsaddr.IsTailscaleIP(ip.Addr()) { + continue + } + if ip.Addr().Is4() && !peerIPv4.IsValid() { + peerIPv4 = ip.Addr() + } else if ip.Addr().Is6() && !peerIPv6.IsValid() { + peerIPv6 = ip.Addr() + } + } + seen := map[netip.Addr]bool{} for _, each := range ctipr.TransitIPs { if seen[each.TransitIP] { resp.TransitIPs = append(resp.TransitIPs, TransitIPResponse{ - Code: OtherFailure, + Code: DuplicateTransitIP, Message: dupeTransitIPMessage, }) + c.logf("[Unexpected] peer attempt to map a transit IP reused a transitIP: node: %s, IP: %v", + n.StableID(), each.TransitIP) continue } - tipresp := c.connector.handleTransitIPRequest(nid, each) + + if _, ok := cfg.appsByName[each.App]; !ok { + resp.TransitIPs = append(resp.TransitIPs, TransitIPResponse{ + Code: UnknownAppName, + Message: unknownAppNameMessage, + }) + c.logf("[Unexpected] peer attempt to map a transit IP referenced unknown app: node: %s, app: %q", + n.StableID(), each.App) + continue + } + tipresp := c.connector.handleTransitIPRequest(n, peerIPv4, peerIPv6, each) seen[each.TransitIP] = true resp.TransitIPs = append(resp.TransitIPs, tipresp) } return resp } -func (s *connector) handleTransitIPRequest(nid tailcfg.NodeID, tipr TransitIPRequest) TransitIPResponse { - s.mu.Lock() - defer s.mu.Unlock() - if s.transitIPs == nil { - s.transitIPs = make(map[tailcfg.NodeID]map[netip.Addr]appAddr) +func (c *connector) handleTransitIPRequest(n tailcfg.NodeView, peerV4 netip.Addr, peerV6 netip.Addr, tipr TransitIPRequest) TransitIPResponse { + if tipr.TransitIP.Is4() != tipr.DestinationIP.Is4() { + c.logf("[Unexpected] peer attempt to map a transit IP to dest IP did not have matching families: node: %s, tIPv4: %v dIPv4: %v", + n.StableID(), tipr.TransitIP.Is4(), tipr.DestinationIP.Is4()) + return TransitIPResponse{Code: AddrFamilyMismatch, Message: addrFamilyMismatchMessage} + } + + // Datapath lookups only have access to the peer IP, and that will match the family + // of the transit IP, so we need to store v4 and v6 mappings separately. + var peerAddr netip.Addr + if tipr.TransitIP.Is4() { + peerAddr = peerV4 + } else { + peerAddr = peerV6 + } + + // If we couldn't find a matching family, return an error. + if !peerAddr.IsValid() { + c.logf("[Unexpected] peer attempt to map a transit IP did not have a matching address family: node: %s, IPv4: %v", + n.StableID(), tipr.TransitIP.Is4()) + return TransitIPResponse{NoMatchingPeerIPFamily, noMatchingPeerIPFamilyMessage} } - peerMap, ok := s.transitIPs[nid] + + c.mu.Lock() + defer c.mu.Unlock() + if c.transitIPs == nil { + c.transitIPs = make(map[netip.Addr]map[netip.Addr]appAddr) + } + peerMap, ok := c.transitIPs[peerAddr] if !ok { peerMap = make(map[netip.Addr]appAddr) - s.transitIPs[nid] = peerMap + c.transitIPs[peerAddr] = peerMap } peerMap[tipr.TransitIP] = appAddr{addr: tipr.DestinationIP, app: tipr.App} return TransitIPResponse{} } -func (s *connector) transitIPTarget(nid tailcfg.NodeID, tip netip.Addr) netip.Addr { - s.mu.Lock() - defer s.mu.Unlock() - return s.transitIPs[nid][tip].addr -} - // TransitIPRequest details a single TransitIP allocation request from a client to a // connector. type TransitIPRequest struct { @@ -281,8 +509,24 @@ const ( OK TransitIPResponseCode = 0 // OtherFailure indicates that the mapping failed for a reason that does not have - // another relevant [TransitIPResponsecode]. + // another relevant [TransitIPResponseCode]. OtherFailure TransitIPResponseCode = 1 + + // DuplicateTransitIP indicates that the same transit address appeared more than + // once in a [ConnectorTransitIPRequest]. + DuplicateTransitIP TransitIPResponseCode = 2 + + // NoMatchingPeerIPFamily indicates that the peer did not have an associated + // IP with the same family as transit IP being registered. + NoMatchingPeerIPFamily = 3 + + // AddrFamilyMismatch indicates that the transit IP and destination IP addresses + // do not belong to the same IP family. + AddrFamilyMismatch = 4 + + // UnknownAppName indicates that the connector is not configured to handle requests + // for the App name that was specified in the request. + UnknownAppName = 5 ) // TransitIPResponse is the response to a TransitIPRequest @@ -304,48 +548,105 @@ type ConnectorTransitIPResponse struct { const AppConnectorsExperimentalAttrName = "tailscale.com/app-connectors-experimental" -// config holds the config from the policy and lookups derived from that. +// ipSets wraps all the IPSets the config needs. +type ipSets struct { + v4Transit *netipx.IPSet + v4Magic *netipx.IPSet + v6Transit *netipx.IPSet + v6Magic *netipx.IPSet +} + +func emptyIPSets() ipSets { + return ipSets{ + v4Transit: &netipx.IPSet{}, + v4Magic: &netipx.IPSet{}, + v6Transit: &netipx.IPSet{}, + v6Magic: &netipx.IPSet{}, + } +} + +// config holds the config derived from the self node view, +// which includes the policy. // config is not safe for concurrent use. type config struct { - isConfigured bool - apps []appctype.Conn25Attr - appsByDomain map[dnsname.FQDN][]string - selfRoutedDomains set.Set[dnsname.FQDN] + isConfigured bool + apps []appctype.Conn25Attr + appsByName map[string]appctype.Conn25Attr + appNamesByDomain map[dnsname.FQDN][]string + appNamesByWCDomain map[dnsname.FQDN][]string + selfAppNames set.Set[string] + ipSets ipSets } -func configFromNodeView(n tailcfg.NodeView) (config, error) { +func configFromNodeView(n tailcfg.NodeView) (*config, error) { apps, err := tailcfg.UnmarshalNodeCapViewJSON[appctype.Conn25Attr](n.CapMap(), AppConnectorsExperimentalAttrName) if err != nil { - return config{}, err + return &config{}, err } if len(apps) == 0 { - return config{}, nil + return &config{}, nil } selfTags := set.SetOf(n.Tags().AsSlice()) - cfg := config{ - isConfigured: true, - apps: apps, - appsByDomain: map[dnsname.FQDN][]string{}, - selfRoutedDomains: set.Set[dnsname.FQDN]{}, + cfg := &config{ + isConfigured: true, + apps: apps, + appsByName: map[string]appctype.Conn25Attr{}, + appNamesByDomain: map[dnsname.FQDN][]string{}, + appNamesByWCDomain: map[dnsname.FQDN][]string{}, + selfAppNames: set.Set[string]{}, + ipSets: emptyIPSets(), } for _, app := range apps { - selfMatchesTags := false - for _, tag := range app.Connectors { - if selfTags.Contains(tag) { - selfMatchesTags = true - break - } - } + normalizedDomains := set.Set[dnsname.FQDN]{} + normalizedWCDomains := set.Set[dnsname.FQDN]{} for _, d := range app.Domains { - fqdn, err := dnsname.ToFQDN(d) + domain, isWild := strings.CutPrefix(d, "*.") + fqdn, err := normalizeDNSName(domain) if err != nil { - return config{}, err + return &config{}, err } - mak.Set(&cfg.appsByDomain, fqdn, append(cfg.appsByDomain[fqdn], app.Name)) - if selfMatchesTags { - cfg.selfRoutedDomains.Add(fqdn) + if isWild && !normalizedWCDomains.Contains(fqdn) { + normalizedWCDomains.Add(fqdn) + mak.Set(&cfg.appNamesByWCDomain, fqdn, append(cfg.appNamesByWCDomain[fqdn], app.Name)) + } else if !isWild && !normalizedDomains.Contains(fqdn) { + normalizedDomains.Add(fqdn) + mak.Set(&cfg.appNamesByDomain, fqdn, append(cfg.appNamesByDomain[fqdn], app.Name)) } } + mak.Set(&cfg.appsByName, app.Name, app) + if slices.ContainsFunc(app.Connectors, selfTags.Contains) { + cfg.selfAppNames.Add(app.Name) + } + + } + + // TODO(fran) 2026-03-18 we don't yet have a proper way to communicate the + // global IP pool config. For now just take it from the first app. + if len(apps) != 0 { + app := apps[0] + v4Mipp, err := ipSetFromIPRanges(app.V4MagicIPPool) + if err != nil { + return &config{}, err + } + v4Tipp, err := ipSetFromIPRanges(app.V4TransitIPPool) + if err != nil { + return &config{}, err + } + v6Mipp, err := ipSetFromIPRanges(app.V6MagicIPPool) + if err != nil { + return &config{}, err + } + v6Tipp, err := ipSetFromIPRanges(app.V6TransitIPPool) + if err != nil { + return &config{}, err + } + ipSets := ipSets{ + v4Magic: v4Mipp, + v4Transit: v4Tipp, + v6Magic: v6Mipp, + v6Transit: v6Tipp, + } + cfg.ipSets = ipSets } return cfg, nil } @@ -355,111 +656,364 @@ func configFromNodeView(n tailcfg.NodeView) (config, error) { // connectors. // It's safe for concurrent use. type client struct { - logf logger.Logf + logf logger.Logf + addrsCh chan addrs + getIPSets func() ipSets - mu sync.Mutex // protects the fields below - magicIPPool *ippool - transitIPPool *ippool - assignments addrAssignments - config config + mu sync.Mutex // protects the fields below + v4MagicIPPool *ippool + v4TransitIPPool *ippool + v6MagicIPPool *ippool + v6TransitIPPool *ippool + assignments addrAssignments + byConnKey map[key.NodePublic]set.Set[netip.Prefix] } -func (c *client) isConfigured() bool { +// transitIPForMagicIP is part of the implementation of the IPMapper interface for dataflows lookups. +// See also [IPMapper.ClientTransitIPForMagicIP]. +func (c *client) transitIPForMagicIP(magicIP netip.Addr) (netip.Addr, bool) { c.mu.Lock() defer c.mu.Unlock() - return c.config.isConfigured + v, ok := c.assignments.lookupByMagicIP(magicIP) + if ok { + return v.transit, true + } + return netip.Addr{}, false } -func (c *client) reconfig(newCfg config) error { +// linkLocalAllow returns true if the provided packet with a link-local Dst address has a +// Dst that is one of our transit IPs, and false otherwise. +// Tailscale's wireguard filters drop link-local unicast packets (see [wgengine/filter/filter.go]) +// but conn25 uses link-local addresses for transit IPs. +// Let the filter know if this is one of our addresses and should be allowed. +func (c *client) linkLocalAllow(p packet.Parsed) (bool, string) { c.mu.Lock() defer c.mu.Unlock() - - c.config = newCfg - - // TODO(fran) this is not the correct way to manage the pools and changes to the pools. - // We probably want to: - // * check the pools haven't changed - // * reset the whole connector if the pools change? or just if they've changed to exclude - // addresses we have in use? - // * have config separate from the apps for this (rather than multiple potentially conflicting places) - // but this works while we are just getting started here. - for _, app := range c.config.apps { - if c.magicIPPool != nil { // just take the first config and never reconfig - break - } - if app.MagicIPPool == nil { - continue - } - mipp, err := ipSetFromIPRanges(app.MagicIPPool) - if err != nil { - return err - } - tipp, err := ipSetFromIPRanges(app.TransitIPPool) - if err != nil { - return err - } - c.magicIPPool = newIPPool(mipp) - c.transitIPPool = newIPPool(tipp) + ok := c.isKnownTransitIP(p.Dst.Addr()) + if ok { + return true, packetFilterAllowReason } - return nil + return false, "" } -func (c *client) isConnectorDomain(domain dnsname.FQDN) bool { +func (c *client) isKnownTransitIP(tip netip.Addr) bool { + _, ok := c.assignments.lookupByTransitIP(tip) + return ok +} + +func (c *client) reconfig() { c.mu.Lock() defer c.mu.Unlock() - appNames, ok := c.config.appsByDomain[domain] - return ok && len(appNames) > 0 + + ipSets := c.getIPSets() + + c.v4MagicIPPool = newIPPool(ipSets.v4Magic) + c.v4TransitIPPool = newIPPool(ipSets.v4Transit) + c.v6MagicIPPool = newIPPool(ipSets.v6Magic) + c.v6TransitIPPool = newIPPool(ipSets.v6Transit) +} + +// getAppsForConnectorDomain returns the slice of app names which match the +// provided domain. Apps which match the domain exactly are preferred, +// otherwise the list of apps comes from the wildcard domain which matches +// the longest suffix of the specified domain. A nil or empty slice is returned +// if no match is found or if the list of matching apps would contain an app +// which is being handled by the self-node's connector. +func (cfg *config) getAppsForConnectorDomain(domain dnsname.FQDN, prefsAdvertiseConnector bool) []string { + // Lookup exact matches first + appNames := cfg.appNamesByDomain[domain] + if len(appNames) == 0 { + // No exact match, check wildcard domains + // We have made the decision that wildcards will match the base domain. + // So example.com will be a match for *.example.com, because we think that + // this is most likely what users will expect. + for d := domain; d != ""; d = d.Parent() { + if appNames = cfg.appNamesByWCDomain[d]; len(appNames) > 0 { + break + } + } + } + + // If we have a candidate match, make sure that no candidate app is pointing + // at a connector on the self-node. + if len(appNames) == 0 || (prefsAdvertiseConnector && slices.ContainsFunc(appNames, cfg.selfAppNames.Contains)) { + return nil + } + return appNames } // reserveAddresses tries to make an assignment of addrs from the address pools // for this domain+dst address, so that this client can use conn25 connectors. +// The name of the matching app is also provided, no validation is done to check whether or not +// the app name refers to a configured app. // It checks that this domain should be routed and that this client is not itself a connector for the domain // and generally if it is valid to make the assignment. -func (c *client) reserveAddresses(domain dnsname.FQDN, dst netip.Addr) (addrs, error) { +func (c *client) reserveAddresses(appName string, domain dnsname.FQDN, dst netip.Addr) (addrs, error) { + if !dst.IsValid() { + return addrs{}, errors.New("dst is not valid") + } c.mu.Lock() defer c.mu.Unlock() if existing, ok := c.assignments.lookupByDomainDst(domain, dst); ok { return existing, nil } - appNames, _ := c.config.appsByDomain[domain] - // only reserve for first app - app := appNames[0] - mip, err := c.magicIPPool.next() - if err != nil { - return addrs{}, err - } - tip, err := c.transitIPPool.next() - if err != nil { - return addrs{}, err + + var mip, tip netip.Addr + var err error + if dst.Is4() { + mip, err = c.v4MagicIPPool.next() + if err != nil { + return addrs{}, err + } + tip, err = c.v4TransitIPPool.next() + if err != nil { + return addrs{}, err + } + } else if dst.Is6() { + mip, err = c.v6MagicIPPool.next() + if err != nil { + return addrs{}, err + } + tip, err = c.v6TransitIPPool.next() + if err != nil { + return addrs{}, err + } + } else { + return addrs{}, errors.New("unexpected neither 4 nor 6") } as := addrs{ dst: dst, magic: mip, transit: tip, - app: app, + app: appName, domain: domain, } if err := c.assignments.insert(as); err != nil { return addrs{}, err } + err = c.enqueueAddressAssignment(as) + if err != nil { + return addrs{}, err + } return as, nil } -func (c *client) enqueueAddressAssignment(addrs addrs) { - // TODO(fran) 2026-02-03 asynchronously send peerapi req to connector to - // allocate these addresses for us. +func (c *client) addTransitIPForConnector(tip netip.Addr, conn tailcfg.NodeView) error { + if conn.Key().IsZero() { + return fmt.Errorf("node with stable ID %q does not have a key", conn.StableID()) + } + + c.mu.Lock() + defer c.mu.Unlock() + return c.insertTransitConnMapping(tip, conn.Key()) +} + +func (e *extension) sendLoop(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case as := <-e.conn25.client.addrsCh: + if err := e.handleAddressAssignment(ctx, as); err != nil { + e.conn25.logf("error handling transit IP assignment (app: %s, mip: %v, src: %v): %v", as.app, as.magic, as.dst, err) + } + } + } } -func (c *client) mapDNSResponse(buf []byte) []byte { +func (e *extension) handleAddressAssignment(ctx context.Context, as addrs) error { + conn, err := e.sendAddressAssignment(ctx, as) + if err != nil { + return err + } + err = e.conn25.client.addTransitIPForConnector(as.transit, conn) + if err != nil { + return err + } + + e.host.AuthReconfigAsync() + return nil +} + +func (c *client) enqueueAddressAssignment(addrs addrs) error { + select { + // TODO(fran) investigate the value of waiting for multiple addresses and sending them + // in one ConnectorTransitIPRequest + case c.addrsCh <- addrs: + return nil + default: + c.logf("address assignment queue full, dropping transit assignment for %v", addrs.domain) + return errors.New("queue full") + } +} + +func (c *client) extraWireGuardAllowedIPs(k key.NodePublic) views.Slice[netip.Prefix] { + c.mu.Lock() + defer c.mu.Unlock() + tips, ok := c.lookupTransitIPsByConnKey(k) + if !ok { + return views.Slice[netip.Prefix]{} + } + return views.SliceOf(tips) +} + +func makePeerAPIReq(ctx context.Context, httpClient *http.Client, urlBase string, as addrs) error { + url := urlBase + "/v0/connector/transit-ip" + + reqBody := ConnectorTransitIPRequest{ + TransitIPs: []TransitIPRequest{{ + TransitIP: as.transit, + DestinationIP: as.dst, + App: as.app, + }}, + } + bs, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("marshalling request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bs)) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("sending request: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("connector returned HTTP %d", resp.StatusCode) + } + + var respBody ConnectorTransitIPResponse + err = jsonDecode(&respBody, resp.Body) + if err != nil { + return fmt.Errorf("decoding response: %w", err) + } + + if len(respBody.TransitIPs) > 0 && respBody.TransitIPs[0].Code != OK { + return fmt.Errorf("connector error: %s", respBody.TransitIPs[0].Message) + } + return nil +} + +func (e *extension) sendAddressAssignment(ctx context.Context, as addrs) (tailcfg.NodeView, error) { + cfg, ok := e.conn25.getConfig() + if !ok { + return tailcfg.NodeView{}, errors.New("not configured") + } + app, ok := cfg.appsByName[as.app] + if !ok { + e.conn25.logf("App not found for app: %s (domain: %s)", as.app, as.domain) + return tailcfg.NodeView{}, errors.New("app not found") + } + + nb := e.host.NodeBackend() + peers := appc.PickConnector(nb, app) + var urlBase string + var conn tailcfg.NodeView + for _, p := range peers { + urlBase = nb.PeerAPIBase(p) + if urlBase != "" { + conn = p + break + } + } + if urlBase == "" { + return tailcfg.NodeView{}, errors.New("no connector peer found to handle address assignment") + } + client := e.backend.Sys().Dialer.Get().PeerAPIHTTPClient() + return conn, makePeerAPIReq(ctx, client, urlBase, as) +} + +type dnsResponseRewrite struct { + domain dnsname.FQDN + dst netip.Addr +} + +func makeServFail(logf logger.Logf, h dnsmessage.Header, q dnsmessage.Question) []byte { + h.Response = true + h.Authoritative = true + h.RCode = dnsmessage.RCodeServerFailure + b := dnsmessage.NewBuilder(nil, h) + err := b.StartQuestions() + if err != nil { + logf("error making servfail: %v", err) + return []byte{} + } + err = b.Question(q) + if err != nil { + logf("error making servfail: %v", err) + return []byte{} + } + bs, err := b.Finish() + if err != nil { + // If there's an error here there's a bug somewhere directly above. + // _possibly_ some kind of question that was parseable but not encodable?, + // otherwise we could panic. + logf("error making servfail: %v", err) + } + return bs +} + +// mapDNSResponse parses and inspects the DNS response. If the domain +// is determined to belong to app this node is client for, it assigns addresses +// for connecting and rewrites the response to contain Magic IPs. +func (c *Conn25) mapDNSResponse(buf []byte) []byte { var p dnsmessage.Parser - if _, err := p.Start(buf); err != nil { + hdr, err := p.Start(buf) + if err != nil { c.logf("error parsing dns response: %v", err) return buf } - if err := p.SkipAllQuestions(); err != nil { + questions, err := p.AllQuestions() + if err != nil { c.logf("error parsing dns response: %v", err) return buf } + // Any message we are interested in has one question (RFC 9619) + if len(questions) != 1 { + return buf + } + question := questions[0] + // The other Class types are not commonly used and supporting them hasn't been considered. + if question.Class != dnsmessage.ClassINET { + return buf + } + queriedDomain, err := normalizeDNSName(question.Name.String()) + if err != nil { + return buf + } + + cfg, ok := c.getConfig() + if !ok { + return buf + } + + appNames := cfg.getAppsForConnectorDomain(queriedDomain, c.prefsAdvertiseConnector.Load()) + if len(appNames) == 0 { + return buf + } + + // There is guaranteed to be at least one matching app, so just take the first one for now + appName := appNames[0] + + // Now we know this is a dns response we think we should rewrite, we're going to provide our response which + // currently means we will: + // * write the questions through as they are + // * not send through the additional section + // * provide our answers, or no answers if we don't handle those answers (possibly in the future we should write through answers for eg TypeTXT) + var answers []dnsResponseRewrite + if question.Type != dnsmessage.TypeA && question.Type != dnsmessage.TypeAAAA { + c.logf("mapping dns response for connector domain, unsupported type: %v", question.Type) + newBuf, err := c.client.rewriteDNSResponse(appName, hdr, questions, answers) + if err != nil { + c.logf("error writing empty response for unsupported type: %v", err) + return makeServFail(c.logf, hdr, question) + } + return newBuf + } for { h, err := p.AnswerHeader() if err == dnsmessage.ErrSectionDone { @@ -467,82 +1021,187 @@ func (c *client) mapDNSResponse(buf []byte) []byte { } if err != nil { c.logf("error parsing dns response: %v", err) - return buf + return makeServFail(c.logf, hdr, question) } - + // other classes are unsupported, and we checked the question was for ClassINET already if h.Class != dnsmessage.ClassINET { + c.logf("unexpected class for connector domain dns response: %v %v", queriedDomain, h.Class) if err := p.SkipAnswer(); err != nil { c.logf("error parsing dns response: %v", err) - return buf + return makeServFail(c.logf, hdr, question) } continue } - switch h.Type { - case dnsmessage.TypeA: - domain, err := dnsname.ToFQDN(h.Name.String()) - if err != nil { - c.logf("bad dnsname: %v", err) - return buf + case dnsmessage.TypeCNAME: + // An A record was asked for, and the answer is a CNAME, this answer will tell us which domain it's a CNAME for + // and a subsequent answer should tell us what the target domains address is (or possibly another CNAME). Drop + // this for now (2026-03-11) but in the near future we should collapse the CNAME chain and map to the ultimate + // destination address (see eg appc/{appconnector,observe}.go). + c.logf("not yet implemented CNAME answer: %v", queriedDomain) + if err := p.SkipAnswer(); err != nil { + c.logf("error parsing dns response: %v", err) + return makeServFail(c.logf, hdr, question) } - if !c.isConnectorDomain(domain) { + case dnsmessage.TypeA, dnsmessage.TypeAAAA: + if h.Type != question.Type { + // would not expect a v4 response to a v6 question or vice versa, don't add a rewrite for this. if err := p.SkipAnswer(); err != nil { c.logf("error parsing dns response: %v", err) - return buf + return makeServFail(c.logf, hdr, question) } continue } - r, err := p.AResource() + domain, err := normalizeDNSName(h.Name.String()) if err != nil { - c.logf("error parsing dns response: %v", err) - return buf + c.logf("bad dnsname: %v", err) + return makeServFail(c.logf, hdr, question) } - addrs, err := c.reserveAddresses(domain, netip.AddrFrom4(r.A)) - if err != nil { - c.logf("error assigning connector addresses: %v", err) - return buf + // answers should be for the domain that was queried + if domain != queriedDomain { + c.logf("unexpected domain for connector domain dns response: %v %v", queriedDomain, domain) + if err := p.SkipAnswer(); err != nil { + c.logf("error parsing dns response: %v", err) + return makeServFail(c.logf, hdr, question) + } + continue } - if !addrs.isValid() { - c.logf("assigned connector addresses unexpectedly empty: %v", err) - return buf + var dstAddr netip.Addr + if h.Type == dnsmessage.TypeA { + r, err := p.AResource() + if err != nil { + c.logf("error parsing dns response: %v", err) + return makeServFail(c.logf, hdr, question) + } + dstAddr = netip.AddrFrom4(r.A) + } else { + r, err := p.AAAAResource() + if err != nil { + c.logf("error parsing dns response: %v", err) + return makeServFail(c.logf, hdr, question) + } + dstAddr = netip.AddrFrom16(r.AAAA) } - c.enqueueAddressAssignment(addrs) + answers = append(answers, dnsResponseRewrite{domain: domain, dst: dstAddr}) default: + // we already checked the question was for a supported type, this answer is unexpected + c.logf("unexpected type for connector domain dns response: %v %v", queriedDomain, h.Type) if err := p.SkipAnswer(); err != nil { c.logf("error parsing dns response: %v", err) - return buf + return makeServFail(c.logf, hdr, question) } - continue } } + newBuf, err := c.client.rewriteDNSResponse(appName, hdr, questions, answers) + if err != nil { + c.logf("error rewriting dns response: %v", err) + return makeServFail(c.logf, hdr, question) + } + return newBuf +} - // TODO(fran) 2026-01-21 return a dns response with addresses - // swapped out for the magic IPs to make conn25 work. - return buf +func (c *client) rewriteDNSResponse(appName string, hdr dnsmessage.Header, questions []dnsmessage.Question, answers []dnsResponseRewrite) ([]byte, error) { + b := dnsmessage.NewBuilder(nil, hdr) + b.EnableCompression() + if err := b.StartQuestions(); err != nil { + return nil, err + } + for _, q := range questions { + if err := b.Question(q); err != nil { + return nil, err + } + } + if err := b.StartAnswers(); err != nil { + return nil, err + } + + // make an answer for each rewrite + for _, rw := range answers { + as, err := c.reserveAddresses(appName, rw.domain, rw.dst) + if err != nil { + return nil, err + } + if !as.isValid() { + return nil, errors.New("connector addresses empty") + } + name, err := dnsmessage.NewName(rw.domain.WithTrailingDot()) + if err != nil { + return nil, err + } + if rw.dst.Is4() { + rhdr := dnsmessage.ResourceHeader{Name: name, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, TTL: 0} + if err := b.AResource(rhdr, dnsmessage.AResource{A: as.magic.As4()}); err != nil { + return nil, err + } + } else if rw.dst.Is6() { + rhdr := dnsmessage.ResourceHeader{Name: name, Type: dnsmessage.TypeAAAA, Class: dnsmessage.ClassINET, TTL: 0} + if err := b.AAAAResource(rhdr, dnsmessage.AAAAResource{AAAA: as.magic.As16()}); err != nil { + return nil, err + } + } else { + return nil, errors.New("unexpected neither 4 nor 6") + } + } + // We do _not_ include the additional section in our rewrite. (We don't want to include + // eg DNSSEC info, or other extra info like related records). + out, err := b.Finish() + if err != nil { + return nil, err + } + return out, nil } type connector struct { logf logger.Logf mu sync.Mutex // protects the fields below - // transitIPs is a map of connector client peer NodeID -> client transitIPs that we update as connector client peers instruct us to, and then use to route traffic to its destination on behalf of connector clients. - transitIPs map[tailcfg.NodeID]map[netip.Addr]appAddr - config config + // transitIPs is a map of connector client peer IP -> client transitIPs that we update as connector client peers instruct us to, and then use to route traffic to its destination on behalf of connector clients. + // Note that each peer could potentially have two maps: one for its IPv4 address, and one for its IPv6 address. The transit IPs map for a given peer IP will contain transit IPs of the same family as the peer's IP. + transitIPs map[netip.Addr]map[netip.Addr]appAddr } -func (s *connector) reconfig(newCfg config) error { - s.mu.Lock() - defer s.mu.Unlock() - s.config = newCfg - return nil +// realIPForTransitIPConnection is part of the implementation of the IPMapper interface for dataflows lookups. +// See also [IPMapper.ConnectorRealIPForTransitIPConnection]. +func (c *connector) realIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, bool) { + c.mu.Lock() + defer c.mu.Unlock() + v, ok := c.lookupBySrcIPAndTransitIP(srcIP, transitIP) + if ok { + return v.addr, true + } + return netip.Addr{}, false +} + +const packetFilterAllowReason = "app connector transit IP" + +// packetFilterAllow returns true if the provided packet has a Src that maps to a peer +// that has a transit IP with us that is the packet Dst, and false otherwise. +func (c *connector) packetFilterAllow(p packet.Parsed) (bool, string) { + c.mu.Lock() + defer c.mu.Unlock() + _, ok := c.lookupBySrcIPAndTransitIP(p.Src.Addr(), p.Dst.Addr()) + if ok { + return true, packetFilterAllowReason + } + return false, "" +} + +func (c *connector) lookupBySrcIPAndTransitIP(srcIP, transitIP netip.Addr) (appAddr, bool) { + m, ok := c.transitIPs[srcIP] + if !ok || m == nil { + return appAddr{}, false + } + v, ok := m[transitIP] + return v, ok } type addrs struct { - dst netip.Addr - magic netip.Addr - transit netip.Addr - domain dnsname.FQDN - app string + dst netip.Addr + magic netip.Addr + transit netip.Addr + domain dnsname.FQDN + app string + expiresAt time.Time } func (c addrs) isValid() bool { @@ -557,28 +1216,105 @@ type domainDst struct { } // addrAssignments is the collection of addrs assigned by this client -// supporting lookup by magicip or domain+dst +// supporting lookup by magic IP, transit IP or domain+dst, or to lookup all +// transit IPs associated with a given connector (identified by its node key). +// byConnKey stores netip.Prefix versions of the transit IPs for use in the +// WireGuard hooks. type addrAssignments struct { byMagicIP map[netip.Addr]addrs + byTransitIP map[netip.Addr]addrs byDomainDst map[domainDst]addrs + clock tstime.Clock } +const defaultExpiry = 48 * time.Hour + func (a *addrAssignments) insert(as addrs) error { - // we likely will want to allow overwriting in the future when we - // have address expiry, but for now this should not happen - if _, ok := a.byMagicIP[as.magic]; ok { - return errors.New("byMagicIP key exists") + return a.insertWithExpiry(as, defaultExpiry) +} + +func (a *addrAssignments) insertWithExpiry(as addrs, d time.Duration) error { + if !as.expiresAt.IsZero() { + return errors.New("expiresAt already set") + } + now := a.clock.Now() + as.expiresAt = now.Add(d) + // we don't expect for addresses to be reused before expiry + if existing, ok := a.byMagicIP[as.magic]; ok { + if !existing.expiresAt.Before(now) { + return errors.New("byMagicIP key exists") + } } ddst := domainDst{domain: as.domain, dst: as.dst} - if _, ok := a.byDomainDst[ddst]; ok { - return errors.New("byDomainDst key exists") + if existing, ok := a.byDomainDst[ddst]; ok { + if !existing.expiresAt.Before(now) { + return errors.New("byDomainDst key exists") + } + } + if existing, ok := a.byTransitIP[as.transit]; ok { + if !existing.expiresAt.Before(now) { + return errors.New("byTransitIP key exists") + } } mak.Set(&a.byMagicIP, as.magic, as) + mak.Set(&a.byTransitIP, as.transit, as) mak.Set(&a.byDomainDst, ddst, as) return nil } func (a *addrAssignments) lookupByDomainDst(domain dnsname.FQDN, dst netip.Addr) (addrs, bool) { v, ok := a.byDomainDst[domainDst{domain: domain, dst: dst}] - return v, ok + if !ok || v.expiresAt.Before(a.clock.Now()) { + return addrs{}, false + } + return v, true +} + +func (a *addrAssignments) lookupByMagicIP(mip netip.Addr) (addrs, bool) { + v, ok := a.byMagicIP[mip] + if !ok || v.expiresAt.Before(a.clock.Now()) { + return addrs{}, false + } + return v, true +} + +func (a *addrAssignments) lookupByTransitIP(tip netip.Addr) (addrs, bool) { + v, ok := a.byTransitIP[tip] + if !ok || v.expiresAt.Before(a.clock.Now()) { + return addrs{}, false + } + return v, true +} + +// insertTransitConnMapping adds an entry to the byConnKey map +// for the provided transitIP (as a prefix). +// The provided transitIP must already be present in the byTransitIP map. +func (c *client) insertTransitConnMapping(tip netip.Addr, connKey key.NodePublic) error { + if _, ok := c.assignments.lookupByTransitIP(tip); !ok { + return errors.New("transit IP is not already known") + } + + ctips, ok := c.byConnKey[connKey] + tipp := netip.PrefixFrom(tip, tip.BitLen()) + if ok { + if ctips.Contains(tipp) { + return errors.New("byConnKey already contains transit") + } + } else { + ctips.Make() + mak.Set(&c.byConnKey, connKey, ctips) + } + ctips.Add(tipp) + return nil +} + +// lookupTransitIPsByConnKey returns a slice containing the transit IPs (as netipPrefix) +// associated with the given connector (identified by node key), or (nil, false) if there is no entry +// for the given key. +func (c *client) lookupTransitIPsByConnKey(k key.NodePublic) ([]netip.Prefix, bool) { + s, ok := c.byConnKey[k] + if !ok { + return nil, false + } + return s.Slice(), true } diff --git a/feature/conn25/conn25_test.go b/feature/conn25/conn25_test.go index d63e84e024738..2ed4190258ae1 100644 --- a/feature/conn25/conn25_test.go +++ b/feature/conn25/conn25_test.go @@ -5,17 +5,31 @@ package conn25 import ( "encoding/json" + "net/http" + "net/http/httptest" "net/netip" - "reflect" + "slices" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "go4.org/mem" "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" + "tailscale.com/net/dns" + "tailscale.com/net/packet" + "tailscale.com/net/tsdial" + "tailscale.com/net/tstun" "tailscale.com/tailcfg" + "tailscale.com/tsd" + "tailscale.com/tstest" "tailscale.com/types/appctype" + "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/opt" "tailscale.com/util/dnsname" "tailscale.com/util/must" "tailscale.com/util/set" @@ -31,217 +45,409 @@ func mustIPSetFromPrefix(s string) *netipx.IPSet { return set } -// TestHandleConnectorTransitIPRequestZeroLength tests that if sent a -// ConnectorTransitIPRequest with 0 TransitIPRequests, we respond with a -// ConnectorTransitIPResponse with 0 TransitIPResponses. -func TestHandleConnectorTransitIPRequestZeroLength(t *testing.T) { - c := newConn25(logger.Discard) - req := ConnectorTransitIPRequest{} - nid := tailcfg.NodeID(1) +// TestHandleConnectorTransitIPRequest tests that if sent a +// request with a transit addr and a destination addr we store that mapping +// and can retrieve it. +func TestHandleConnectorTransitIPRequest(t *testing.T) { - resp := c.handleConnectorTransitIPRequest(nid, req) - if len(resp.TransitIPs) != 0 { - t.Fatalf("n TransitIPs in response: %d, want 0", len(resp.TransitIPs)) - } -} + const appName = "TestApp" -// TestHandleConnectorTransitIPRequestStoresAddr tests that if sent a -// request with a transit addr and a destination addr we store that mapping -// and can retrieve it. If sent another req with a different dst for that transit addr -// we store that instead. -func TestHandleConnectorTransitIPRequestStoresAddr(t *testing.T) { - c := newConn25(logger.Discard) - nid := tailcfg.NodeID(1) - tip := netip.MustParseAddr("0.0.0.1") - dip := netip.MustParseAddr("1.2.3.4") - dip2 := netip.MustParseAddr("1.2.3.5") - mr := func(t, d netip.Addr) ConnectorTransitIPRequest { - return ConnectorTransitIPRequest{ - TransitIPs: []TransitIPRequest{ - {TransitIP: t, DestinationIP: d}, - }, - } - } + // Peer IPs + pipV4_1 := netip.MustParseAddr("100.101.101.101") + pipV4_2 := netip.MustParseAddr("100.101.101.102") - resp := c.handleConnectorTransitIPRequest(nid, mr(tip, dip)) - if len(resp.TransitIPs) != 1 { - t.Fatalf("n TransitIPs in response: %d, want 1", len(resp.TransitIPs)) - } - got := resp.TransitIPs[0].Code - if got != TransitIPResponseCode(0) { - t.Fatalf("TransitIP Code: %d, want 0", got) - } - gotAddr := c.connector.transitIPTarget(nid, tip) - if gotAddr != dip { - t.Fatalf("Connector stored destination for tip: %v, want %v", gotAddr, dip) - } + pipV6_1 := netip.MustParseAddr("fd7a:115c:a1e0::101") + pipV6_3 := netip.MustParseAddr("fd7a:115c:a1e0::103") - // mapping can be overwritten - resp2 := c.handleConnectorTransitIPRequest(nid, mr(tip, dip2)) - if len(resp2.TransitIPs) != 1 { - t.Fatalf("n TransitIPs in response: %d, want 1", len(resp2.TransitIPs)) - } - got2 := resp.TransitIPs[0].Code - if got2 != TransitIPResponseCode(0) { - t.Fatalf("TransitIP Code: %d, want 0", got2) - } - gotAddr2 := c.connector.transitIPTarget(nid, tip) - if gotAddr2 != dip2 { - t.Fatalf("Connector stored destination for tip: %v, want %v", gotAddr, dip2) - } -} + // Transit IPs + tipV4_1 := netip.MustParseAddr("0.0.0.1") + tipV4_2 := netip.MustParseAddr("0.0.0.2") -// TestHandleConnectorTransitIPRequestMultipleTIP tests that we can -// get a req with multiple mappings and we store them all. Including -// multiple transit addrs for the same destination. -func TestHandleConnectorTransitIPRequestMultipleTIP(t *testing.T) { - c := newConn25(logger.Discard) - nid := tailcfg.NodeID(1) - tip := netip.MustParseAddr("0.0.0.1") - tip2 := netip.MustParseAddr("0.0.0.2") - tip3 := netip.MustParseAddr("0.0.0.3") - dip := netip.MustParseAddr("1.2.3.4") - dip2 := netip.MustParseAddr("1.2.3.5") - req := ConnectorTransitIPRequest{ - TransitIPs: []TransitIPRequest{ - {TransitIP: tip, DestinationIP: dip}, - {TransitIP: tip2, DestinationIP: dip2}, - // can store same dst addr for multiple transit addrs - {TransitIP: tip3, DestinationIP: dip}, - }, - } - resp := c.handleConnectorTransitIPRequest(nid, req) - if len(resp.TransitIPs) != 3 { - t.Fatalf("n TransitIPs in response: %d, want 3", len(resp.TransitIPs)) - } - - for i := 0; i < 3; i++ { - got := resp.TransitIPs[i].Code - if got != TransitIPResponseCode(0) { - t.Fatalf("i=%d TransitIP Code: %d, want 0", i, got) - } - } - gotAddr1 := c.connector.transitIPTarget(nid, tip) - if gotAddr1 != dip { - t.Fatalf("Connector stored destination for tip(%v): %v, want %v", tip, gotAddr1, dip) - } - gotAddr2 := c.connector.transitIPTarget(nid, tip2) - if gotAddr2 != dip2 { - t.Fatalf("Connector stored destination for tip(%v): %v, want %v", tip2, gotAddr2, dip2) - } - gotAddr3 := c.connector.transitIPTarget(nid, tip3) - if gotAddr3 != dip { - t.Fatalf("Connector stored destination for tip(%v): %v, want %v", tip3, gotAddr3, dip) - } -} + tipV6_1 := netip.MustParseAddr("FE80::1") -// TestHandleConnectorTransitIPRequestSameTIP tests that if we get -// a req that has more than one TransitIPRequest for the same transit addr -// only the first is stored, and the subsequent ones get an error code and -// message in the response. -func TestHandleConnectorTransitIPRequestSameTIP(t *testing.T) { - c := newConn25(logger.Discard) - nid := tailcfg.NodeID(1) - tip := netip.MustParseAddr("0.0.0.1") - tip2 := netip.MustParseAddr("0.0.0.2") - dip := netip.MustParseAddr("1.2.3.4") - dip2 := netip.MustParseAddr("1.2.3.5") - dip3 := netip.MustParseAddr("1.2.3.6") - req := ConnectorTransitIPRequest{ - TransitIPs: []TransitIPRequest{ - {TransitIP: tip, DestinationIP: dip}, - // cannot have dupe TransitIPs in one ConnectorTransitIPRequest - {TransitIP: tip, DestinationIP: dip2}, - {TransitIP: tip2, DestinationIP: dip3}, + // Destination IPs + dipV4_1 := netip.MustParseAddr("10.0.0.1") + dipV4_2 := netip.MustParseAddr("10.0.0.2") + dipV4_3 := netip.MustParseAddr("10.0.0.3") + + dipV6_1 := netip.MustParseAddr("fc00::1") + + // Peer nodes + peerV4V6 := (&tailcfg.Node{ + ID: tailcfg.NodeID(1), + Addresses: []netip.Prefix{netip.PrefixFrom(pipV4_1, 32), netip.PrefixFrom(pipV6_1, 128)}, + }).View() + + peerV4Only := (&tailcfg.Node{ + ID: tailcfg.NodeID(2), + Addresses: []netip.Prefix{netip.PrefixFrom(pipV4_2, 32)}, + }).View() + + peerV6Only := (&tailcfg.Node{ + ID: tailcfg.NodeID(3), + Addresses: []netip.Prefix{netip.PrefixFrom(pipV6_3, 128)}, + }).View() + + tests := []struct { + name string + ctipReqPeers []tailcfg.NodeView // One entry per request and the other + ctipReqs []ConnectorTransitIPRequest // arrays in this struct must have the same + wants []ConnectorTransitIPResponse // cardinality + // For checking lookups: + // The outer array needs to correspond to the number of requests, + // can be nil if no lookups need to be done after the request is processed. + // + // The middle array is the set of lookups for the corresponding request. + // + // The inner array is a tuple of (PeerIP, TransitIP, ExpectedDestinationIP) + wantLookups [][][]netip.Addr + }{ + // Single peer, single request with success ipV4 + { + name: "one-peer-one-req-ipv4", + ctipReqPeers: []tailcfg.NodeView{peerV4Only}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{{TransitIP: tipV4_1, DestinationIP: dipV4_1, App: appName}}}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{{Code: OK, Message: ""}}}, + }, + wantLookups: [][][]netip.Addr{ + {{pipV4_2, tipV4_1, dipV4_1}}, + }, + }, + // Single peer, single request with success ipV6 + { + name: "one-peer-one-req-ipv6", + ctipReqPeers: []tailcfg.NodeView{peerV6Only}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{{TransitIP: tipV6_1, DestinationIP: dipV6_1, App: appName}}}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{{Code: OK, Message: ""}}}, + }, + wantLookups: [][][]netip.Addr{ + {{pipV6_3, tipV6_1, dipV6_1}}, + }, + }, + // Single peer, multi request with success, ipV4 + { + name: "one-peer-multi-req-ipv4", + ctipReqPeers: []tailcfg.NodeView{peerV4Only, peerV4Only}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{{TransitIP: tipV4_1, DestinationIP: dipV4_1, App: appName}}}, + {TransitIPs: []TransitIPRequest{{TransitIP: tipV4_2, DestinationIP: dipV4_2, App: appName}}}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{{Code: OK, Message: ""}}}, + {TransitIPs: []TransitIPResponse{{Code: OK, Message: ""}}}, + }, + wantLookups: [][][]netip.Addr{ + {{pipV4_2, tipV4_1, dipV4_1}}, + {{pipV4_2, tipV4_2, dipV4_2}}, + }, + }, + // Single peer, multi request remap tip, ipV4 + { + name: "one-peer-remap-tip", + ctipReqPeers: []tailcfg.NodeView{peerV4Only, peerV4Only}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{{TransitIP: tipV4_1, DestinationIP: dipV4_1, App: appName}}}, + {TransitIPs: []TransitIPRequest{{TransitIP: tipV4_1, DestinationIP: dipV4_2, App: appName}}}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{{Code: OK, Message: ""}}}, + {TransitIPs: []TransitIPResponse{{Code: OK, Message: ""}}}, + }, + wantLookups: [][][]netip.Addr{ + {{pipV4_2, tipV4_1, dipV4_1}}, + {{pipV4_2, tipV4_1, dipV4_2}}, + }, + }, + // Single peer, multi request with success, ipV4 and ipV6 + { + name: "one-peer-multi-req-ipv4-ipv6", + ctipReqPeers: []tailcfg.NodeView{peerV4V6, peerV4V6}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{{TransitIP: tipV4_1, DestinationIP: dipV4_1, App: appName}}}, + {TransitIPs: []TransitIPRequest{{TransitIP: tipV6_1, DestinationIP: dipV6_1, App: appName}}}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{{Code: OK, Message: ""}}}, + {TransitIPs: []TransitIPResponse{{Code: OK, Message: ""}}}, + }, + wantLookups: [][][]netip.Addr{ + {{pipV4_1, tipV4_1, dipV4_1}}, + {{pipV4_1, tipV4_1, dipV4_1}, {pipV6_1, tipV6_1, dipV6_1}, {pipV4_1, tipV6_1, netip.Addr{}}}, + }, + }, + // Single peer, multi map with success, ipV4 + { + name: "one-peer-multi-map-ipv4", + ctipReqPeers: []tailcfg.NodeView{peerV4Only}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{ + {TransitIP: tipV4_1, DestinationIP: dipV4_1, App: appName}, + {TransitIP: tipV4_2, DestinationIP: dipV4_2, App: appName}, + }}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{{Code: OK, Message: ""}, {Code: OK, Message: ""}}}, + }, + wantLookups: [][][]netip.Addr{ + {{pipV4_2, tipV4_1, dipV4_1}, {pipV4_2, tipV4_2, dipV4_2}}, + }, + }, + // Single peer, error reuse same tip in one request, ensure all non-dup requests are processed + { + name: "one-peer-multi-map-duplicate-tip", + ctipReqPeers: []tailcfg.NodeView{peerV4Only}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{ + {TransitIP: tipV4_1, DestinationIP: dipV4_1, App: appName}, + {TransitIP: tipV4_1, DestinationIP: dipV4_2, App: appName}, + {TransitIP: tipV4_2, DestinationIP: dipV4_3, App: appName}, + }}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{ + {Code: OK, Message: ""}, + {Code: DuplicateTransitIP, Message: dupeTransitIPMessage}, + {Code: OK, Message: ""}}, + }, + }, + wantLookups: [][][]netip.Addr{ + {{pipV4_2, tipV4_1, dipV4_1}, {pipV4_2, tipV4_2, dipV4_3}}, + }, + }, + // Multi peer, success reuse same tip in one request + { + name: "multi-peer-duplicate-tip", + ctipReqPeers: []tailcfg.NodeView{peerV4V6, peerV4Only}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{{TransitIP: tipV4_1, DestinationIP: dipV4_1, App: appName}}}, + {TransitIPs: []TransitIPRequest{{TransitIP: tipV4_1, DestinationIP: dipV4_2, App: appName}}}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{{Code: OK, Message: ""}}}, + {TransitIPs: []TransitIPResponse{{Code: OK, Message: ""}}}, + }, + wantLookups: [][][]netip.Addr{ + {{pipV4_1, tipV4_1, dipV4_1}}, + {{pipV4_1, tipV4_1, dipV4_1}, {pipV4_2, tipV4_1, dipV4_2}}, + }, + }, + // Single peer, multi map, multiple tip to same dip + { + name: "one-peer-multi-map-multi-tip-to-dip", + ctipReqPeers: []tailcfg.NodeView{peerV4Only}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{ + {TransitIP: tipV4_1, DestinationIP: dipV4_1, App: appName}, + {TransitIP: tipV4_2, DestinationIP: dipV4_1, App: appName}, + }}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{{Code: OK, Message: ""}, {Code: OK, Message: ""}}}, + }, + wantLookups: [][][]netip.Addr{ + {{pipV4_2, tipV4_1, dipV4_1}, {pipV4_2, tipV4_2, dipV4_1}}, + }, + }, + // Single peer, ipv4 tip, no ipv4 pip, but ipv6 tip works + { + name: "one-peer-missing-ipv4-family", + ctipReqPeers: []tailcfg.NodeView{peerV6Only}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{ + {TransitIP: tipV4_1, DestinationIP: dipV4_1, App: appName}, + {TransitIP: tipV6_1, DestinationIP: dipV6_1, App: appName}, + }}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{ + {Code: NoMatchingPeerIPFamily, Message: noMatchingPeerIPFamilyMessage}, + {Code: OK, Message: ""}, + }}, + }, + wantLookups: [][][]netip.Addr{ + {{pipV6_3, tipV4_1, netip.Addr{}}, {pipV6_3, tipV6_1, dipV6_1}}, + }, + }, + // Single peer, ipv6 tip, no ipv6 pip, but ipv4 tip works + { + name: "one-peer-missing-ipv6-family", + ctipReqPeers: []tailcfg.NodeView{peerV4Only}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{ + {TransitIP: tipV6_1, DestinationIP: dipV6_1, App: appName}, + {TransitIP: tipV4_1, DestinationIP: dipV4_1, App: appName}, + }}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{ + {Code: NoMatchingPeerIPFamily, Message: noMatchingPeerIPFamilyMessage}, + {Code: OK, Message: ""}, + }}, + }, + wantLookups: [][][]netip.Addr{ + {{pipV4_2, tipV6_1, netip.Addr{}}, {pipV4_2, tipV4_1, dipV4_1}}, + }, + }, + // Single peer, mismatched transit and destination ips + { + name: "one-peer-mismatched-tip-dip", + ctipReqPeers: []tailcfg.NodeView{peerV4Only}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{{TransitIP: tipV4_1, DestinationIP: dipV6_1, App: appName}}}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{{Code: AddrFamilyMismatch, Message: addrFamilyMismatchMessage}}}, + }, + wantLookups: [][][]netip.Addr{ + {{pipV4_2, tipV4_1, netip.Addr{}}}, + }, + }, + // Single peer, invalid app name + { + name: "one-peer-invalid-app", + ctipReqPeers: []tailcfg.NodeView{peerV4Only}, + ctipReqs: []ConnectorTransitIPRequest{ + {TransitIPs: []TransitIPRequest{{TransitIP: tipV4_1, DestinationIP: dipV4_1, App: "Unknown App"}}}, + }, + wants: []ConnectorTransitIPResponse{ + {TransitIPs: []TransitIPResponse{{Code: UnknownAppName, Message: unknownAppNameMessage}}}, + }, + wantLookups: [][][]netip.Addr{ + {{pipV4_2, tipV4_1, netip.Addr{}}}, + }, }, } - resp := c.handleConnectorTransitIPRequest(nid, req) - if len(resp.TransitIPs) != 3 { - t.Fatalf("n TransitIPs in response: %d, want 3", len(resp.TransitIPs)) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + switch { + case len(tt.ctipReqPeers) != len(tt.ctipReqs): + t.Fatalf("error in test setup: ctipReqPeers has length %d does not match ctipReqs length %d", + len(tt.ctipReqPeers), len(tt.ctipReqs)) + case len(tt.ctipReqPeers) != len(tt.wants): + t.Fatalf("error in test setup: ctipReqPeers has length %d does not match wants length %d", + len(tt.ctipReqPeers), len(tt.wants)) + case len(tt.ctipReqPeers) != len(tt.wantLookups): + t.Fatalf("error in test setup: ctipReqPeers has length %d does not match wantLookups length %d", + len(tt.ctipReqPeers), len(tt.wantLookups)) + } - got := resp.TransitIPs[0].Code - if got != TransitIPResponseCode(0) { - t.Fatalf("i=0 TransitIP Code: %d, want 0", got) - } - msg := resp.TransitIPs[0].Message - if msg != "" { - t.Fatalf("i=0 TransitIP Message: \"%s\", want \"%s\"", msg, "") - } - got1 := resp.TransitIPs[1].Code - if got1 != TransitIPResponseCode(1) { - t.Fatalf("i=1 TransitIP Code: %d, want 1", got1) - } - msg1 := resp.TransitIPs[1].Message - if msg1 != dupeTransitIPMessage { - t.Fatalf("i=1 TransitIP Message: \"%s\", want \"%s\"", msg1, dupeTransitIPMessage) - } - got2 := resp.TransitIPs[2].Code - if got2 != TransitIPResponseCode(0) { - t.Fatalf("i=2 TransitIP Code: %d, want 0", got2) - } - msg2 := resp.TransitIPs[2].Message - if msg2 != "" { - t.Fatalf("i=2 TransitIP Message: \"%s\", want \"%s\"", msg, "") - } + // Use the same Conn25 for each request in the test and seed it with a test app name. + c := newConn25(logger.Discard) + c.reconfig(&config{ + isConfigured: true, + appsByName: map[string]appctype.Conn25Attr{appName: {}}, + }) - gotAddr1 := c.connector.transitIPTarget(nid, tip) - if gotAddr1 != dip { - t.Fatalf("Connector stored destination for tip(%v): %v, want %v", tip, gotAddr1, dip) - } - gotAddr2 := c.connector.transitIPTarget(nid, tip2) - if gotAddr2 != dip3 { - t.Fatalf("Connector stored destination for tip(%v): %v, want %v", tip2, gotAddr2, dip3) - } -} + for i, peer := range tt.ctipReqPeers { + req := tt.ctipReqs[i] + want := tt.wants[i] -// TestGetDstIPUnknownTIP tests that unknown transit addresses can be looked up without problem. -func TestTransitIPTargetUnknownTIP(t *testing.T) { - c := newConn25(logger.Discard) - nid := tailcfg.NodeID(1) - tip := netip.MustParseAddr("0.0.0.1") - got := c.connector.transitIPTarget(nid, tip) - want := netip.Addr{} - if got != want { - t.Fatalf("Unknown transit addr, want: %v, got %v", want, got) + resp := c.handleConnectorTransitIPRequest(peer, req) + + // Ensure that we have the expected number of responses + if len(resp.TransitIPs) != len(want.TransitIPs) { + t.Fatalf("wrong number of TransitIPs in response %d: got %d, want %d", + i, len(resp.TransitIPs), len(want.TransitIPs)) + } + + // Validate the contents of each response + for j, tipResp := range resp.TransitIPs { + wantResp := want.TransitIPs[j] + if tipResp.Code != wantResp.Code { + t.Errorf("transitIP.Code mismatch in response %d, tipresp %d: got %d, want %d", + i, j, tipResp.Code, wantResp.Code) + } + if tipResp.Message != wantResp.Message { + t.Errorf("transitIP.Message mismatch in response %d, tipresp %d: got %q, want %q", + i, j, tipResp.Message, wantResp.Message) + } + } + + // Validate the state of the transitIP map after each request + if tt.wantLookups[i] != nil { + for j, wantLookup := range tt.wantLookups[i] { + if len(wantLookup) != 3 { + t.Fatalf("test setup error: wantLookup for request %d lookup %d contains %d IPs, expected 3", + i, j, len(wantLookup)) + } + pip, tip, wantDip := wantLookup[0], wantLookup[1], wantLookup[2] + aa, _ := c.connector.lookupBySrcIPAndTransitIP(pip, tip) + gotDip := aa.addr + if gotDip != wantDip { + t.Errorf("wrong result on lookup[%d][%d] ([%v], [%v]): got [%v] expected [%v]", + i, j, pip, tip, gotDip, wantDip) + } + } + } + } + }) } } func TestReserveIPs(t *testing.T) { c := newConn25(logger.Discard) - c.client.magicIPPool = newIPPool(mustIPSetFromPrefix("100.64.0.0/24")) - c.client.transitIPPool = newIPPool(mustIPSetFromPrefix("169.254.0.0/24")) - mbd := map[dnsname.FQDN][]string{} - mbd["example.com."] = []string{"a"} - c.client.config.appsByDomain = mbd - - dst := netip.MustParseAddr("0.0.0.1") - addrs, err := c.client.reserveAddresses("example.com.", dst) - if err != nil { - t.Fatal(err) + const appName = "a" + domainStr := "example.com." + cfg := &config{ + isConfigured: true, + appsByName: map[string]appctype.Conn25Attr{appName: {}}, + ipSets: ipSets{ + v4Magic: mustIPSetFromPrefix("100.64.0.0/24"), + v6Magic: mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0100::/80"), + v4Transit: mustIPSetFromPrefix("169.254.0.0/24"), + v6Transit: mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0200::/80"), + }, } + c.reconfig(cfg) + domain := must.Get(dnsname.ToFQDN(domainStr)) - wantDst := netip.MustParseAddr("0.0.0.1") // same as dst we pass in - wantMagic := netip.MustParseAddr("100.64.0.0") // first from magic pool - wantTransit := netip.MustParseAddr("169.254.0.0") // first from transit pool - wantApp := "a" // the app name related to example.com. - wantDomain := must.Get(dnsname.ToFQDN("example.com.")) - - if wantDst != addrs.dst { - t.Errorf("want %v, got %v", wantDst, addrs.dst) - } - if wantMagic != addrs.magic { - t.Errorf("want %v, got %v", wantMagic, addrs.magic) - } - if wantTransit != addrs.transit { - t.Errorf("want %v, got %v", wantTransit, addrs.transit) - } - if wantApp != addrs.app { - t.Errorf("want %s, got %s", wantApp, addrs.app) - } - if wantDomain != addrs.domain { - t.Errorf("want %s, got %s", wantDomain, addrs.domain) + for _, tt := range []struct { + name string + dst netip.Addr + wantMagic netip.Addr + wantTransit netip.Addr + }{ + { + name: "v4", + dst: netip.MustParseAddr("0.0.0.1"), + wantMagic: netip.MustParseAddr("100.64.0.0"), // first from magic pool + wantTransit: netip.MustParseAddr("169.254.0.0"), // first from transit pool + }, + { + name: "v6", + dst: netip.MustParseAddr("::1"), + wantMagic: netip.MustParseAddr("fd7a:115c:a1e0:a99c:100::"), // first from magic pool + wantTransit: netip.MustParseAddr("fd7a:115c:a1e0:a99c:200::"), // first from transit pool + }, + } { + t.Run(tt.name, func(t *testing.T) { + addrs, err := c.client.reserveAddresses(appName, domain, tt.dst) + if err != nil { + t.Fatal(err) + } + if tt.dst != addrs.dst { + t.Errorf("want %v, got %v", tt.dst, addrs.dst) + } + if tt.wantMagic != addrs.magic { + t.Errorf("want %v, got %v", tt.wantMagic, addrs.magic) + } + if tt.wantTransit != addrs.transit { + t.Errorf("want %v, got %v", tt.wantTransit, addrs.transit) + } + if appName != addrs.app { + t.Errorf("want %s, got %s", appName, addrs.app) + } + if domain != addrs.domain { + t.Errorf("want %s, got %s", domain, addrs.domain) + } + }) } } @@ -254,29 +460,40 @@ func TestReconfig(t *testing.T) { } c := newConn25(logger.Discard) + if c.isConfigured() { + t.Fatal("expected Conn25 isConfigured() to report unconfigured before reconfig") + } + sn := (&tailcfg.Node{ CapMap: capMap, }).View() + cfg := mustConfig(t, sn) + c.reconfig(cfg) - err := c.reconfig(sn) - if err != nil { - t.Fatal(err) + if !c.isConfigured() { + t.Fatal("expected Conn25 isConfigured() to report configured after reconfig") + } + + cfg, ok := c.getConfig() + if !ok { + t.Fatal("expected Conn25 getConfig() to report configured after reconfig") } - if len(c.client.config.apps) != 1 || c.client.config.apps[0].Name != "app1" { - t.Fatalf("want apps to have one entry 'app1', got %v", c.client.config.apps) + if len(cfg.apps) != 1 || cfg.apps[0].Name != "app1" { + t.Fatalf("want apps to have one entry 'app1', got %v", cfg.apps) } } -func TestConfigReconfig(t *testing.T) { +func TestConfigFromNodeView(t *testing.T) { for _, tt := range []struct { - name string - rawCfg string - cfg []appctype.Conn25Attr - tags []string - wantErr bool - wantAppsByDomain map[dnsname.FQDN][]string - wantSelfRoutedDomains set.Set[dnsname.FQDN] + name string + rawCfg string + cfg []appctype.Conn25Attr + tags []string + wantErr bool + wantAppsByDomain map[dnsname.FQDN][]string + wantAppsByWCDomain map[dnsname.FQDN][]string + wantSelfAppNames set.Set[string] }{ { name: "bad-config", @@ -294,10 +511,11 @@ func TestConfigReconfig(t *testing.T) { "a.example.com.": {"one"}, "b.example.com.": {"two"}, }, - wantSelfRoutedDomains: set.SetOf([]dnsname.FQDN{"a.example.com."}), + wantAppsByWCDomain: map[dnsname.FQDN][]string{}, + wantSelfAppNames: set.SetOf([]string{"one"}), }, { - name: "more-complex", + name: "more-complex-with-connector-self-domains", cfg: []appctype.Conn25Attr{ {Name: "one", Domains: []string{"1.a.example.com", "1.b.example.com"}, Connectors: []string{"tag:one", "tag:onea"}}, {Name: "two", Domains: []string{"2.b.example.com", "2.c.example.com"}, Connectors: []string{"tag:two", "tag:twoa"}}, @@ -314,7 +532,63 @@ func TestConfigReconfig(t *testing.T) { "4.b.example.com.": {"four"}, "4.d.example.com.": {"four"}, }, - wantSelfRoutedDomains: set.SetOf([]dnsname.FQDN{"1.a.example.com.", "1.b.example.com.", "4.b.example.com.", "4.d.example.com."}), + wantAppsByWCDomain: map[dnsname.FQDN][]string{}, + wantSelfAppNames: set.SetOf([]string{"one", "four"}), + }, + { + name: "eligible-connector-no-matching-tag-no-self-domains", + cfg: []appctype.Conn25Attr{ + {Name: "one", Domains: []string{"a.example.com"}, Connectors: []string{"tag:one"}}, + {Name: "two", Domains: []string{"b.example.com"}, Connectors: []string{"tag:two"}}, + }, + tags: []string{"tag:unrelated"}, + wantAppsByDomain: map[dnsname.FQDN][]string{ + "a.example.com.": {"one"}, + "b.example.com.": {"two"}, + }, + wantAppsByWCDomain: map[dnsname.FQDN][]string{}}, + { + name: "wildcard-collapse-and-deduplication", + cfg: []appctype.Conn25Attr{ + {Name: "one", Domains: []string{"*.example.com", "example.com"}, Connectors: []string{"tag:one"}}, + {Name: "two", Domains: []string{"example.com", "sub.example.com"}, Connectors: []string{"tag:two"}}, + }, + tags: []string{"tag:one", "tag:two"}, + wantAppsByDomain: map[dnsname.FQDN][]string{ + "example.com.": {"one", "two"}, + "sub.example.com.": {"two"}, + }, + wantAppsByWCDomain: map[dnsname.FQDN][]string{ + "example.com.": {"one"}, + }, + wantSelfAppNames: set.SetOf([]string{"one", "two"}), + }, + { + // Domain names that differ only in case must be treated as the same + // domain and the app name must appear exactly once in appNamesByDomain, + // not once per case variant. + name: "case-variant-exact-domains-deduplicated-within-app", + cfg: []appctype.Conn25Attr{ + {Name: "one", Domains: []string{"EXAMPLE.com", "example.COM", "Example.COM"}, Connectors: []string{"tag:one"}}, + }, + tags: []string{"tag:one"}, + wantAppsByDomain: map[dnsname.FQDN][]string{ + "example.com.": {"one"}, + }, + wantAppsByWCDomain: map[dnsname.FQDN][]string{}, + wantSelfAppNames: set.SetOf([]string{"one"}), + }, + { + // Same as above but for wildcard domains: *.EXAMPLE.com and *.example.COM + // must collapse to a single entry in appNamesByWCDomain. + name: "case-variant-wildcard-domains-deduplicated-within-app", + cfg: []appctype.Conn25Attr{ + {Name: "one", Domains: []string{"*.EXAMPLE.com", "*.example.COM"}, Connectors: []string{"tag:one"}}, + }, + tags: []string{"tag:one"}, + wantAppsByDomain: map[dnsname.FQDN][]string{}, + wantAppsByWCDomain: map[dnsname.FQDN][]string{"example.com.": {"one"}}, + wantSelfAppNames: set.SetOf([]string{"one"}), }, } { t.Run(tt.name, func(t *testing.T) { @@ -336,98 +610,312 @@ func TestConfigReconfig(t *testing.T) { CapMap: capMap, Tags: tt.tags, }).View() + c, err := configFromNodeView(sn) if (err != nil) != tt.wantErr { t.Fatalf("wantErr: %t, err: %v", tt.wantErr, err) } - if diff := cmp.Diff(tt.wantAppsByDomain, c.appsByDomain); diff != "" { + if diff := cmp.Diff(tt.wantAppsByDomain, c.appNamesByDomain); diff != "" { t.Errorf("appsByDomain diff (-want, +got):\n%s", diff) } - if diff := cmp.Diff(tt.wantSelfRoutedDomains, c.selfRoutedDomains); diff != "" { - t.Errorf("selfRoutedDomains diff (-want, +got):\n%s", diff) + if diff := cmp.Diff(tt.wantAppsByWCDomain, c.appNamesByWCDomain); diff != "" { + t.Errorf("appsByWCDomain diff (-want, +got):\n%s", diff) + } + if diff := cmp.Diff(tt.wantSelfAppNames, c.selfAppNames); diff != "" { + t.Errorf("selfAppNames diff (-want, +got):\n%s", diff) + } + }) + } +} + +func TestGetAppsForDomainName(t *testing.T) { + defaultSN := makeSelfNode( + t, + []appctype.Conn25Attr{ + {Name: "one", Domains: []string{"*.example.com", "example.com"}, Connectors: []string{"tag:one"}}, + {Name: "two", Domains: []string{"sub.example.com", "example.com"}, Connectors: []string{"tag:two"}}, + {Name: "three", Domains: []string{"*.sub.example.com"}, Connectors: []string{"tag:three"}}, + {Name: "four", Domains: []string{"a.sub.example.com"}, Connectors: []string{"tag:four"}}, + {Name: "self-routed", Domains: []string{"*.wildcard.com", "exact-match.com"}, Connectors: []string{"tag:self-routed"}}, + }, + []string{"tag:self-routed"}, + ) + + for _, tt := range []struct { + name string + isConnector bool + domain dnsname.FQDN + wantApps []string + }{ + { + name: "no-match", + domain: "nomatch.com.", + wantApps: nil, + }, + { + name: "exact-match", + domain: "example.com.", + wantApps: []string{"one", "two"}, + }, + { + name: "wildcard-subdomain-match", + domain: "a.example.com.", + wantApps: []string{"one"}, + }, + { + name: "exact-subdomain-match", + domain: "sub.example.com.", + wantApps: []string{"two"}, + }, + { + name: "wildcard-sub-of-subdomain-match", + domain: "b.sub.example.com.", + wantApps: []string{"three"}, + }, + { + name: "exact-sub-of-subdomain-match", + domain: "a.sub.example.com.", + wantApps: []string{"four"}, + }, + { + name: "exact-domain-matches-wildcard", + domain: "wildcard.com.", + wantApps: []string{"self-routed"}, + }, + { + name: "self-routed-exact-domain-suppressed", + isConnector: true, + domain: "exact-match.com.", + wantApps: nil, + }, + { + // Self node is an eligible connector for "wildcard-self-app" via + // *.wildcard.com, so the wildcard match must also be suppressed. + name: "self-routed-wildcard-domain-suppressed", + isConnector: true, + domain: "sub.wildcard.com.", + wantApps: nil, + }, + { + // "other-app" is not on a self-connector tag, so it must not be suppressed. + name: "non-self-routed-domain-not-suppressed", + isConnector: true, + domain: "example.com.", + wantApps: []string{"one", "two"}, + }, + { + // Even though the app's connector tag matches the self node's tags, + // if the node is not an eligible connector (Advertise=false) then + // isSelfRoutedApp returns false and the domain is forwarded normally. + name: "not-eligible-connector-not-suppressed", + domain: "exact-match.com.", + wantApps: []string{"self-routed"}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c := newConn25(logger.Discard) + if tt.isConnector { + c.prefsAdvertiseConnector.Store(true) + } + cfg := mustConfig(t, defaultSN) + c.reconfig(cfg) + cfg, ok := c.getConfig() + if !ok { + t.Fatal("could not get config") + } + gotApps := cfg.getAppsForConnectorDomain(tt.domain, tt.isConnector) + if diff := cmp.Diff(tt.wantApps, gotApps); diff != "" { + t.Errorf("unexpected appNames result: diff (-want, +got):\n%s", diff) } }) } } -func makeSelfNode(t *testing.T, attr appctype.Conn25Attr, tags []string) tailcfg.NodeView { +func makeSelfNode(t *testing.T, attrs []appctype.Conn25Attr, tags []string) tailcfg.NodeView { t.Helper() - bs, err := json.Marshal(attr) - if err != nil { - t.Fatalf("unexpected error in test setup: %v", err) + cfg := make([]tailcfg.RawMessage, 0, len(attrs)) + for i, attr := range attrs { + bs, err := json.Marshal(attr) + if err != nil { + t.Fatalf("unexpected error in test setup at index %d: %v", i, err) + } + cfg = append(cfg, tailcfg.RawMessage(bs)) } - cfg := []tailcfg.RawMessage{tailcfg.RawMessage(bs)} capMap := tailcfg.NodeCapMap{ tailcfg.NodeCapability(AppConnectorsExperimentalAttrName): cfg, } + return (&tailcfg.Node{ CapMap: capMap, Tags: tags, }).View() } -func rangeFrom(from, to string) netipx.IPRange { +var ( + testPrefsNotConnector = (&ipn.Prefs{AppConnector: ipn.AppConnectorPrefs{Advertise: false}}).View() +) + +func mustConfig(t *testing.T, selfNode tailcfg.NodeView) *config { + t.Helper() + cfg, err := configFromNodeView(selfNode) + if err != nil { + t.Fatal(err) + } + return cfg +} + +func v4RangeFrom(from, to string) netipx.IPRange { return netipx.IPRangeFrom( netip.MustParseAddr("100.64.0."+from), netip.MustParseAddr("100.64.0."+to), ) } -func TestMapDNSResponse(t *testing.T) { - makeDNSResponse := func(domain string, addrs []dnsmessage.AResource) []byte { - b := dnsmessage.NewBuilder(nil, - dnsmessage.Header{ - ID: 1, - Response: true, - Authoritative: true, - RCode: dnsmessage.RCodeSuccess, - }) - b.EnableCompression() - - if err := b.StartQuestions(); err != nil { - t.Fatal(err) - } +func v6RangeFrom(from, to string) netipx.IPRange { + return netipx.IPRangeFrom( + netip.MustParseAddr("fd7a:115c:a1e0:a99c:"+from+"::"), + netip.MustParseAddr("fd7a:115c:a1e0:a99c:"+to+"::"), + ) +} - if err := b.Question(dnsmessage.Question{ - Name: dnsmessage.MustNewName(domain), +func makeDNSResponse(t *testing.T, domain string, addrs []*dnsmessage.AResource) []byte { + t.Helper() + name := dnsmessage.MustNewName(domain) + questions := []dnsmessage.Question{ + { + Name: name, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, - }); err != nil { - t.Fatal(err) - } - - if err := b.StartAnswers(); err != nil { - t.Fatal(err) - } - - for _, addr := range addrs { - b.AResource( - dnsmessage.ResourceHeader{ - Name: dnsmessage.MustNewName(domain), - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - }, - addr, - ) - } - - outbs, err := b.Finish() - if err != nil { - t.Fatal(err) + }, + } + var answers []dnsmessage.Resource + for _, addr := range addrs { + ans := dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: name, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + Body: addr, } - return outbs + answers = append(answers, ans) + } + additional := []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: name, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AResource{A: [4]byte{9, 9, 9, 9}}, + }, + } + return makeDNSResponseForSections(t, questions, answers, additional) +} + +func makeV6DNSResponse(t *testing.T, domain string, addrs []*dnsmessage.AAAAResource) []byte { + t.Helper() + name := dnsmessage.MustNewName(domain) + questions := []dnsmessage.Question{ + { + Name: name, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + } + var answers []dnsmessage.Resource + for _, addr := range addrs { + ans := dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: name, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + Body: addr, + } + answers = append(answers, ans) + } + return makeDNSResponseForSections(t, questions, answers, nil) +} + +func makeDNSResponseForSections(t *testing.T, questions []dnsmessage.Question, answers []dnsmessage.Resource, additional []dnsmessage.Resource) []byte { + t.Helper() + b := dnsmessage.NewBuilder(nil, + dnsmessage.Header{ + ID: 1, + Response: true, + Authoritative: true, + RCode: dnsmessage.RCodeSuccess, + }) + b.EnableCompression() + + if err := b.StartQuestions(); err != nil { + t.Fatal(err) + } + + for _, q := range questions { + if err := b.Question(q); err != nil { + t.Fatal(err) + } + } + + if err := b.StartAnswers(); err != nil { + t.Fatal(err) } + for _, ans := range answers { + switch ans.Header.Type { + case dnsmessage.TypeA: + body, ok := (ans.Body).(*dnsmessage.AResource) + if !ok { + t.Fatalf("unexpected answer type, update test") + } + b.AResource(ans.Header, *body) + case dnsmessage.TypeAAAA: + body, ok := (ans.Body).(*dnsmessage.AAAAResource) + if !ok { + t.Fatalf("unexpected answer type, update test") + } + b.AAAAResource(ans.Header, *body) + default: + t.Fatalf("unhandled answer type, update test: %v", ans.Header.Type) + } + } + + if err := b.StartAdditionals(); err != nil { + t.Fatal(err) + } + for _, add := range additional { + body, ok := (add.Body).(*dnsmessage.AResource) + if !ok { + t.Fatalf("unexpected additional type, update test") + } + b.AResource(add.Header, *body) + } + + outbs, err := b.Finish() + if err != nil { + t.Fatal(err) + } + return outbs +} + +func TestMapDNSResponseAssignsAddrs(t *testing.T) { for _, tt := range []struct { - name string - domain string - addrs []dnsmessage.AResource - wantByMagicIP map[netip.Addr]addrs + name string + appDomains []string + domain string + v4Addrs []*dnsmessage.AResource + v6Addrs []*dnsmessage.AAAAResource + selfTags []string + isEligibleConnector bool + wantByMagicIP map[netip.Addr]addrs }{ { - name: "one-ip-matches", - domain: "example.com.", - addrs: []dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + name: "one-ip-matches", + appDomains: []string{"example.com"}, + domain: "example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, // these are 'expected' because they are the beginning of the provided pools wantByMagicIP: map[netip.Addr]addrs{ netip.MustParseAddr("100.64.0.0"): { @@ -440,9 +928,35 @@ func TestMapDNSResponse(t *testing.T) { }, }, { - name: "multiple-ip-matches", - domain: "example.com.", - addrs: []dnsmessage.AResource{ + name: "v6-ip-matches", + appDomains: []string{"example.com"}, + domain: "example.com.", + v6Addrs: []*dnsmessage.AAAAResource{ + {AAAA: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}}, + {AAAA: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}}, + }, + wantByMagicIP: map[netip.Addr]addrs{ + netip.MustParseAddr("fd7a:115c:a1e0:a99c::"): { + domain: "example.com.", + dst: netip.MustParseAddr("::1"), + magic: netip.MustParseAddr("fd7a:115c:a1e0:a99c:0::"), + transit: netip.MustParseAddr("fd7a:115c:a1e0:a99c:40::"), + app: "app1", + }, + netip.MustParseAddr("fd7a:115c:a1e0:a99c::1"): { + domain: "example.com.", + dst: netip.MustParseAddr("::2"), + magic: netip.MustParseAddr("fd7a:115c:a1e0:a99c:0::1"), + transit: netip.MustParseAddr("fd7a:115c:a1e0:a99c:40::1"), + app: "app1", + }, + }, + }, + { + name: "multiple-ip-matches", + appDomains: []string{"example.com"}, + domain: "example.com.", + v4Addrs: []*dnsmessage.AResource{ {A: [4]byte{1, 0, 0, 0}}, {A: [4]byte{2, 0, 0, 0}}, }, @@ -464,61 +978,1272 @@ func TestMapDNSResponse(t *testing.T) { }, }, { - name: "no-domain-match", - domain: "x.example.com.", - addrs: []dnsmessage.AResource{ + name: "no-domain-match", + appDomains: []string{"foo.example.com"}, + domain: "bad.example.com.", + v4Addrs: []*dnsmessage.AResource{ {A: [4]byte{1, 0, 0, 0}}, {A: [4]byte{2, 0, 0, 0}}, }, }, + { + name: "no-rewrite-self-routed-domain", + appDomains: []string{"example.com"}, + domain: "example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + selfTags: []string{"tag:woo"}, + isEligibleConnector: true, + }, + { + name: "rewrite-tagged-but-not-eligible-connector", + appDomains: []string{"example.com"}, + domain: "example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + selfTags: []string{"tag:woo"}, + // isEligibleConnector is false: tag matches but prefs not set, + // so DNS response should be rewritten normally. + wantByMagicIP: map[netip.Addr]addrs{ + netip.MustParseAddr("100.64.0.0"): { + domain: "example.com.", + dst: netip.MustParseAddr("1.0.0.0"), + magic: netip.MustParseAddr("100.64.0.0"), + transit: netip.MustParseAddr("100.64.0.40"), + app: "app1", + }, + }, + }, + { + name: "rewrite-eligible-connector-no-matching-tag", + appDomains: []string{"example.com"}, + domain: "example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + selfTags: []string{"tag:unrelated"}, + isEligibleConnector: true, + // isEligibleConnector is true but tag doesn't match the app, + // so DNS response should be rewritten normally. + wantByMagicIP: map[netip.Addr]addrs{ + netip.MustParseAddr("100.64.0.0"): { + domain: "example.com.", + dst: netip.MustParseAddr("1.0.0.0"), + magic: netip.MustParseAddr("100.64.0.0"), + transit: netip.MustParseAddr("100.64.0.40"), + app: "app1", + }, + }, + }, + { + name: "subdomain-matches-wildcard", + appDomains: []string{"*.example.com"}, + domain: "sub.example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + // these are 'expected' because they are the beginning of the provided pools + wantByMagicIP: map[netip.Addr]addrs{ + netip.MustParseAddr("100.64.0.0"): { + domain: "sub.example.com.", + dst: netip.MustParseAddr("1.0.0.0"), + magic: netip.MustParseAddr("100.64.0.0"), + transit: netip.MustParseAddr("100.64.0.40"), + app: "app1", + }, + }, + }, + { + name: "exact-subdomain-matches", + appDomains: []string{"example.com", "sub.example.com"}, + domain: "sub.example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + // these are 'expected' because they are the beginning of the provided pools + wantByMagicIP: map[netip.Addr]addrs{ + netip.MustParseAddr("100.64.0.0"): { + domain: "sub.example.com.", + dst: netip.MustParseAddr("1.0.0.0"), + magic: netip.MustParseAddr("100.64.0.0"), + transit: netip.MustParseAddr("100.64.0.40"), + app: "app1", + }, + }, + }, + { + name: "wildcard-subdomain-matches-subdomain", + appDomains: []string{"example.com", "*.sub.example.com"}, + domain: "a.sub.example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + // these are 'expected' because they are the beginning of the provided pools + wantByMagicIP: map[netip.Addr]addrs{ + netip.MustParseAddr("100.64.0.0"): { + domain: "a.sub.example.com.", + dst: netip.MustParseAddr("1.0.0.0"), + magic: netip.MustParseAddr("100.64.0.0"), + transit: netip.MustParseAddr("100.64.0.40"), + app: "app1", + }, + }, + }, } { t.Run(tt.name, func(t *testing.T) { - dnsResp := makeDNSResponse(tt.domain, tt.addrs) - sn := makeSelfNode(t, appctype.Conn25Attr{ - Name: "app1", - Connectors: []string{"tag:woo"}, - Domains: []string{"example.com"}, - MagicIPPool: []netipx.IPRange{rangeFrom("0", "10"), rangeFrom("20", "30")}, - TransitIPPool: []netipx.IPRange{rangeFrom("40", "50")}, - }, []string{}) + var dnsResp []byte + if len(tt.v4Addrs) > 0 { + dnsResp = makeDNSResponse(t, tt.domain, tt.v4Addrs) + } else { + dnsResp = makeV6DNSResponse(t, tt.domain, tt.v6Addrs) + } + sn := makeSelfNode(t, []appctype.Conn25Attr{{ + Name: "app1", + Connectors: []string{"tag:woo"}, + Domains: tt.appDomains, + V4MagicIPPool: []netipx.IPRange{v4RangeFrom("0", "10"), v4RangeFrom("20", "30")}, + V6MagicIPPool: []netipx.IPRange{v6RangeFrom("0", "10"), v6RangeFrom("20", "30")}, + V4TransitIPPool: []netipx.IPRange{v4RangeFrom("40", "50")}, + V6TransitIPPool: []netipx.IPRange{v6RangeFrom("40", "50")}, + }}, tt.selfTags) + c := newConn25(logger.Discard) - c.reconfig(sn) + cfg := mustConfig(t, sn) + c.reconfig(cfg) + c.prefsAdvertiseConnector.Store(tt.isEligibleConnector) - bs := c.mapDNSResponse(dnsResp) - if !reflect.DeepEqual(dnsResp, bs) { - t.Fatal("shouldn't be changing the bytes (yet)") - } - if diff := cmp.Diff(tt.wantByMagicIP, c.client.assignments.byMagicIP, cmpopts.EquateComparable(addrs{}, netip.Addr{})); diff != "" { + c.mapDNSResponse(dnsResp) + if diff := cmp.Diff( + tt.wantByMagicIP, + c.client.assignments.byMagicIP, + cmp.AllowUnexported(addrs{}), + cmpopts.IgnoreFields(addrs{}, "expiresAt"), + cmpopts.EquateComparable(netip.Addr{}), + ); diff != "" { t.Errorf("byMagicIP diff (-want, +got):\n%s", diff) } }) } } +func TestNormalizedDNSNames(t *testing.T) { + tests := []struct { + name string + domain string + want dnsname.FQDN + }{ + {name: "no-change", domain: "example.com.", want: "example.com."}, + {name: "mixed-case", domain: "eXAmPle.COM", want: "example.com."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := normalizeDNSName(tt.domain) + if err != nil { + t.Errorf("unexpected error %v", err) + } + if got != tt.want { + t.Errorf("Unexpected result, want %q, got %q", tt.want, got) + } + }) + } +} + func TestReserveAddressesDeduplicated(t *testing.T) { - c := newConn25(logger.Discard) - c.client.magicIPPool = newIPPool(mustIPSetFromPrefix("100.64.0.0/24")) - c.client.transitIPPool = newIPPool(mustIPSetFromPrefix("169.254.0.0/24")) - c.client.config.appsByDomain = map[dnsname.FQDN][]string{"example.com.": {"a"}} + for _, tt := range []struct { + name string + dst netip.Addr + }{ + { + name: "v4", + dst: netip.MustParseAddr("0.0.0.1"), + }, + { + name: "v6", + dst: netip.MustParseAddr("::1"), + }, + } { + t.Run(tt.name, func(t *testing.T) { + const appName = "a" + conn25 := newConn25(t.Logf) + c := conn25.client + c.v4MagicIPPool = newIPPool(mustIPSetFromPrefix("100.64.0.0/24")) + c.v6MagicIPPool = newIPPool(mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0100::/80")) + c.v4TransitIPPool = newIPPool(mustIPSetFromPrefix("169.254.0.0/24")) + c.v6TransitIPPool = newIPPool(mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0200::/80")) - dst := netip.MustParseAddr("0.0.0.1") - first, err := c.client.reserveAddresses("example.com.", dst) - if err != nil { + first, err := c.reserveAddresses(appName, "example.com.", tt.dst) + if err != nil { + t.Fatal(err) + } + + second, err := c.reserveAddresses(appName, "example.com.", tt.dst) + if err != nil { + t.Fatal(err) + } + + if first.magic != second.magic { + t.Errorf("expected same magic addrs on repeated call, got first=%v second=%v", first.magic, second.magic) + } + if got := len(c.assignments.byMagicIP); got != 1 { + t.Errorf("want 1 entry in byMagicIP, got %d", got) + } + if got := len(c.assignments.byDomainDst); got != 1 { + t.Errorf("want 1 entry in byDomainDst, got %d", got) + } + + }) + } +} + +type testNodeBackend struct { + ipnext.NodeBackend + peers []tailcfg.NodeView + peerAPIURL string // should be per peer but there's only one peer in our test so this is ok for now +} + +func (nb *testNodeBackend) AppendMatchingPeers(base []tailcfg.NodeView, pred func(tailcfg.NodeView) bool) []tailcfg.NodeView { + for _, p := range nb.peers { + if pred(p) { + base = append(base, p) + } + } + return base +} + +func (nb *testNodeBackend) PeerHasPeerAPI(p tailcfg.NodeView) bool { + return true +} + +func (nb *testNodeBackend) PeerAPIBase(p tailcfg.NodeView) string { + return nb.peerAPIURL +} + +type testProfileServices struct { + ipnext.ProfileServices + prefs ipn.PrefsView +} + +func (p *testProfileServices) CurrentPrefs() ipn.PrefsView { return p.prefs } +func (p *testProfileServices) CurrentProfileState() (ipn.LoginProfileView, ipn.PrefsView) { + return ipn.LoginProfileView{}, p.prefs +} + +type testHost struct { + ipnext.Host + nb ipnext.NodeBackend + hooks ipnext.Hooks + prefs ipn.PrefsView + authReconfigAsync func() +} + +func (h *testHost) NodeBackend() ipnext.NodeBackend { return h.nb } +func (h *testHost) Hooks() *ipnext.Hooks { return &h.hooks } +func (h *testHost) Profiles() ipnext.ProfileServices { return &testProfileServices{prefs: h.prefs} } +func (h *testHost) AuthReconfigAsync() { h.authReconfigAsync() } + +type testSafeBackend struct { + ipnext.SafeBackend + sys *tsd.System +} + +func newTestSafeBackend() *testSafeBackend { + sb := &testSafeBackend{} + sys := &tsd.System{} + sys.Dialer.Set(&tsdial.Dialer{Logf: logger.Discard}) + sys.DNSManager.Set(&dns.Manager{}) + sys.Tun.Set(&tstun.Wrapper{}) + sb.sys = sys + return sb +} + +func (b *testSafeBackend) Sys() *tsd.System { return b.sys } + +// TestAddressAssignmentIsHandled tests that after enqueueAddress has been called +// we handle the assignment asynchronously by: +// - making a peerapi request to a peer. +// - calling AuthReconfigAsync on the host. +func TestAddressAssignmentIsHandled(t *testing.T) { + // make a fake peer to test against + received := make(chan ConnectorTransitIPRequest, 1) + peersAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v0/connector/transit-ip" { + http.Error(w, "unexpected path", http.StatusNotFound) + return + } + var req ConnectorTransitIPRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "bad body", http.StatusBadRequest) + return + } + received <- req + resp := ConnectorTransitIPResponse{ + TransitIPs: []TransitIPResponse{{Code: OK}}, + } + json.NewEncoder(w).Encode(resp) + })) + defer peersAPI.Close() + + connectorPeer := (&tailcfg.Node{ + ID: tailcfg.NodeID(1), + Tags: []string{"tag:woo"}, + Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(), + Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 1: 0xff, 31: 0x01})), + }).View() + + ext := &extension{ + conn25: newConn25(logger.Discard), + backend: newTestSafeBackend(), + } + authReconfigAsyncCalled := make(chan struct{}, 1) + if err := ext.Init(&testHost{ + nb: &testNodeBackend{ + peers: []tailcfg.NodeView{connectorPeer}, + peerAPIURL: peersAPI.URL, + }, + prefs: testPrefsNotConnector, + authReconfigAsync: func() { + authReconfigAsyncCalled <- struct{}{} + }, + }); err != nil { t.Fatal(err) } + defer ext.Shutdown() + + sn := makeSelfNode(t, []appctype.Conn25Attr{{ + Name: "app1", + Connectors: []string{"tag:woo"}, + Domains: []string{"example.com"}, + }}, []string{}) + + cfg := mustConfig(t, sn) + ext.conn25.reconfig(cfg) + + as := addrs{ + dst: netip.MustParseAddr("1.2.3.4"), + magic: netip.MustParseAddr("100.64.0.0"), + transit: netip.MustParseAddr("169.254.0.1"), + domain: "example.com.", + app: "app1", + } + if err := ext.conn25.client.assignments.insert(as); err != nil { + t.Fatalf("error inserting address assignments: %v", err) + } + ext.conn25.client.enqueueAddressAssignment(as) + + select { + case got := <-received: + if len(got.TransitIPs) != 1 { + t.Fatalf("want 1 TransitIP in request, got %d", len(got.TransitIPs)) + } + tip := got.TransitIPs[0] + if tip.TransitIP != as.transit { + t.Errorf("TransitIP: got %v, want %v", tip.TransitIP, as.transit) + } + if tip.DestinationIP != as.dst { + t.Errorf("DestinationIP: got %v, want %v", tip.DestinationIP, as.dst) + } + if tip.App != as.app { + t.Errorf("App: got %q, want %q", tip.App, as.app) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for connector to receive request") + } + select { + case <-authReconfigAsyncCalled: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for AuthReconfigAsync to be called") + } +} - second, err := c.client.reserveAddresses("example.com.", dst) +func parseResponse(t *testing.T, buf []byte) ([]dnsmessage.Resource, []dnsmessage.Resource) { + t.Helper() + var p dnsmessage.Parser + header, err := p.Start(buf) + if err != nil { + t.Fatalf("parsing DNS response: %v", err) + } + if header.RCode != dnsmessage.RCodeSuccess { + t.Fatalf("RCode want: %v, got: %v", dnsmessage.RCodeSuccess, header.RCode) + } + if err := p.SkipAllQuestions(); err != nil { + t.Fatalf("skipping questions: %v", err) + } + answers, err := p.AllAnswers() + if err != nil { + t.Fatalf("reading answers: %v", err) + } + if err := p.SkipAllAuthorities(); err != nil { + t.Fatalf("skipping questions: %v", err) + } + additionals, err := p.AllAdditionals() if err != nil { + t.Fatalf("reading additionals: %v", err) + } + return answers, additionals +} + +func TestMapDNSResponseRewritesResponses(t *testing.T) { + configuredDomain := "example.com" + domainName := configuredDomain + "." + dnsMessageName := dnsmessage.MustNewName(domainName) + sn := makeSelfNode(t, []appctype.Conn25Attr{{ + Name: "app1", + Connectors: []string{"tag:connector"}, + Domains: []string{configuredDomain}, + V4MagicIPPool: []netipx.IPRange{v4RangeFrom("0", "10")}, + V4TransitIPPool: []netipx.IPRange{v4RangeFrom("40", "50")}, + V6MagicIPPool: []netipx.IPRange{netipx.IPRangeFrom(netip.MustParseAddr("2606:4700::6812:100"), netip.MustParseAddr("2606:4700::6812:1ff"))}, + V6TransitIPPool: []netipx.IPRange{netipx.IPRangeFrom(netip.MustParseAddr("2606:4700::6813:100"), netip.MustParseAddr("2606:4700::6813:1ff"))}, + }}, []string{}) + + cfg := mustConfig(t, sn) + + compareToRecords := func(t *testing.T, resources []dnsmessage.Resource, want []netip.Addr) { + t.Helper() + var got []netip.Addr + for _, r := range resources { + if b, ok := r.Body.(*dnsmessage.AResource); ok { + got = append(got, netip.AddrFrom4(b.A)) + } else if b, ok := r.Body.(*dnsmessage.AAAAResource); ok { + got = append(got, netip.AddrFrom16(b.AAAA)) + } + } + if diff := cmp.Diff(want, got, cmpopts.EquateComparable(netip.Addr{})); diff != "" { + t.Fatalf("A/AAAA records mismatch (-want +got):\n%s", diff) + } + } + + assertParsesToAnswers := func(want []netip.Addr) func(t *testing.T, bs []byte) { + return func(t *testing.T, bs []byte) { + t.Helper() + answers, _ := parseResponse(t, bs) + compareToRecords(t, answers, want) + } + } + + assertParsesToAdditionals := func(want []netip.Addr) func(t *testing.T, bs []byte) { + return func(t *testing.T, bs []byte) { + t.Helper() + _, additionals := parseResponse(t, bs) + compareToRecords(t, additionals, want) + } + } + + assertBytes := func(want []byte) func(t *testing.T, bs []byte) { + return func(t *testing.T, bs []byte) { + t.Helper() + if diff := cmp.Diff(want, bs); diff != "" { + t.Fatalf("bytes mismatch (-want +got):\n%s", diff) + } + } + } + assertServFail := func(t *testing.T, bs []byte) { + var p dnsmessage.Parser + header, err := p.Start(bs) + if err != nil { + t.Fatalf("parsing DNS response: %v", err) + } + if header.RCode != dnsmessage.RCodeServerFailure { + t.Fatalf("RCode want: %v, got: %v", dnsmessage.RCodeServerFailure, header.RCode) + } + } + + ipv6ResponseUnhandledDomain := makeV6DNSResponse(t, "tailscale.com.", []*dnsmessage.AAAAResource{ + {AAAA: netip.MustParseAddr("2606:4700::6812:1a78").As16()}, + {AAAA: netip.MustParseAddr("2606:4700::6812:1b78").As16()}, + }) + + ipv4ResponseUnhandledDomain := makeDNSResponse(t, "tailscale.com.", []*dnsmessage.AResource{ + {A: netip.MustParseAddr("1.2.3.4").As4()}, + {A: netip.MustParseAddr("5.6.7.8").As4()}, + }) + + nonINETQuestionResp := makeDNSResponseForSections(t, []dnsmessage.Question{ + { + Name: dnsMessageName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassCHAOS, + }, + }, nil, nil) + + for _, tt := range []struct { + name string + toMap []byte + assertFx func(*testing.T, []byte) + }{ + { + name: "unparseable", + toMap: []byte{1, 2, 3, 4}, + assertFx: assertBytes([]byte{1, 2, 3, 4}), + }, + { + name: "maps-multi-typea-answers", + toMap: makeDNSResponse(t, domainName, []*dnsmessage.AResource{ + {A: netip.MustParseAddr("1.2.3.4").As4()}, + {A: netip.MustParseAddr("5.6.7.8").As4()}, + }), + assertFx: assertParsesToAnswers( + []netip.Addr{ + netip.MustParseAddr("100.64.0.0"), + netip.MustParseAddr("100.64.0.1"), + }, + ), + }, + { + name: "ipv6-multiple", + toMap: makeV6DNSResponse(t, domainName, []*dnsmessage.AAAAResource{ + {AAAA: netip.MustParseAddr("2606:4700::6812:1a78").As16()}, + {AAAA: netip.MustParseAddr("2606:4700::6812:1b78").As16()}, + }), + assertFx: assertParsesToAnswers( + []netip.Addr{ + netip.MustParseAddr("2606:4700::6812:100"), + netip.MustParseAddr("2606:4700::6812:101"), + }, + ), + }, + { + name: "not-our-domain", + toMap: ipv4ResponseUnhandledDomain, + assertFx: assertBytes(ipv4ResponseUnhandledDomain), + }, + { + name: "ipv6-not-our-domain", + toMap: ipv6ResponseUnhandledDomain, + assertFx: assertBytes(ipv6ResponseUnhandledDomain), + }, + { + name: "case-insensitive", + toMap: makeDNSResponse(t, "eXample.com.", []*dnsmessage.AResource{ + {A: netip.MustParseAddr("1.2.3.4").As4()}, + {A: netip.MustParseAddr("5.6.7.8").As4()}, + }), + assertFx: assertParsesToAnswers( + []netip.Addr{ + netip.MustParseAddr("100.64.0.0"), + netip.MustParseAddr("100.64.0.1"), + }, + ), + }, + { + name: "unhandled-keeps-additional-section", + toMap: makeDNSResponse(t, "tailscale.com.", []*dnsmessage.AResource{ + {A: netip.MustParseAddr("1.2.3.4").As4()}, + {A: netip.MustParseAddr("5.6.7.8").As4()}, + }), + assertFx: assertParsesToAdditionals( + // additionals are added in makeDNSResponse + []netip.Addr{ + netip.MustParseAddr("9.9.9.9"), + }, + ), + }, + { + name: "handled-strips-additional-section", + toMap: makeDNSResponse(t, domainName, []*dnsmessage.AResource{ + {A: netip.MustParseAddr("1.2.3.4").As4()}, + {A: netip.MustParseAddr("5.6.7.8").As4()}, + }), + assertFx: assertParsesToAdditionals(nil), + }, + { + name: "servfail-when-we-should-handle-but-cant", + // produced by + // makeDNSResponse(t, domainName, []*dnsmessage.AResource{{A: netip.MustParseAddr("1.2.3.4").As4()}}) + // and then taking 17 bytes off the end. So that the parsing of it breaks after we have decided we should handle it. + // Frozen like this so that it doesn't depend on the implementation of dnsmessage. + toMap: []byte{0, 1, 132, 0, 0, 1, 0, 1, 0, 0, 0, 1, 7, 101, 120, 97, 109, 112, 108, 101, 3, 99, 111, 109, 0, 0, 1, 0, 1, 192, 12, 0, 1, 0, 1, 0, 0, 0, 0, 0, 4, 1, 2, 3}, + assertFx: assertServFail, + }, + { + name: "not-inet-question", + toMap: nonINETQuestionResp, + assertFx: assertBytes(nonINETQuestionResp), + }, + { + name: "not-inet-answer", + toMap: makeDNSResponseForSections(t, + []dnsmessage.Question{ + { + Name: dnsMessageName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: dnsMessageName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassCHAOS, + }, + Body: &dnsmessage.AResource{A: netip.MustParseAddr("1.2.3.4").As4()}, + }, + }, + nil, + ), + assertFx: assertParsesToAnswers(nil), + }, + { + name: "answer-domain-mismatch", + toMap: makeDNSResponseForSections(t, + []dnsmessage.Question{ + { + Name: dnsMessageName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("tailscale.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AResource{A: netip.MustParseAddr("1.2.3.4").As4()}, + }, + }, + nil, + ), + assertFx: assertParsesToAnswers(nil), + }, + { + name: "answer-type-mismatch-want-v4", + toMap: makeDNSResponseForSections(t, + []dnsmessage.Question{ + { + Name: dnsMessageName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: dnsMessageName, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AAAAResource{AAAA: netip.MustParseAddr("1.2.3.4").As16()}, + }, + { + Header: dnsmessage.ResourceHeader{ + Name: dnsMessageName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AResource{A: netip.MustParseAddr("5.6.7.8").As4()}, + }, + }, + nil, + ), + assertFx: assertParsesToAnswers([]netip.Addr{netip.MustParseAddr("100.64.0.0")}), + }, + { + name: "answer-type-mismatch-want-v6", + toMap: makeDNSResponseForSections(t, + []dnsmessage.Question{ + { + Name: dnsMessageName, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + }, + []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: dnsMessageName, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AAAAResource{AAAA: netip.MustParseAddr("1.2.3.4").As16()}, + }, + { + Header: dnsmessage.ResourceHeader{ + Name: dnsMessageName, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AResource{A: netip.MustParseAddr("5.6.7.8").As4()}, + }, + }, + nil, + ), + assertFx: assertParsesToAnswers([]netip.Addr{netip.MustParseAddr("2606:4700::6812:100")}), + }, + } { + t.Run(tt.name, func(t *testing.T) { + c := newConn25(logger.Discard) + c.reconfig(cfg) + bs := c.mapDNSResponse(tt.toMap) + tt.assertFx(t, bs) + }) + } +} + +func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) { + // make a fake peer API to test against, for all peers + peersAPI := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v0/connector/transit-ip" { + http.Error(w, "unexpected path", http.StatusNotFound) + return + } + var req ConnectorTransitIPRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "bad body", http.StatusBadRequest) + return + } + resp := ConnectorTransitIPResponse{ + TransitIPs: []TransitIPResponse{{Code: OK}}, + } + json.NewEncoder(w).Encode(resp) + })) + defer peersAPI.Close() + + connectorPeers := []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: tailcfg.NodeID(1), + Tags: []string{"tag:woo"}, + Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(), + Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 31: 0x01})), + }).View(), + (&tailcfg.Node{ + ID: tailcfg.NodeID(2), + Tags: []string{"tag:hoo"}, + Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(), + Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 31: 0x02})), + }).View(), + } + + ext := &extension{ + conn25: newConn25(logger.Discard), + backend: newTestSafeBackend(), + } + authReconfigAsyncCalled := make(chan struct{}, 1) + if err := ext.Init(&testHost{ + nb: &testNodeBackend{ + peers: connectorPeers, + peerAPIURL: peersAPI.URL, + }, + prefs: testPrefsNotConnector, + authReconfigAsync: func() { + authReconfigAsyncCalled <- struct{}{} + }, + }); err != nil { t.Fatal(err) } + defer ext.Shutdown() + + sn := makeSelfNode(t, []appctype.Conn25Attr{ + { + Name: "app1", + Connectors: []string{"tag:woo"}, + Domains: []string{"woo.example.com"}, + }, + { + Name: "app2", + Connectors: []string{"tag:hoo"}, + Domains: []string{"hoo.example.com"}, + }, + }, []string{}) + + cfg := mustConfig(t, sn) + ext.conn25.reconfig(cfg) + + type lookup struct { + connKey key.NodePublic + expectedIPs []netip.Prefix + expectedOk bool + } + + transitIPs := []netip.Prefix{ + netip.MustParsePrefix("169.254.0.1/32"), + netip.MustParsePrefix("169.254.0.2/32"), + netip.MustParsePrefix("169.254.0.3/32"), + } + // Each step performs an insert on the provided addrs + // and then does the lookups. + steps := []struct { + name string + as addrs + lookups []lookup + }{ + { + name: "step-1-conn1-tip1", + as: addrs{ + dst: netip.MustParseAddr("1.2.3.1"), + magic: netip.MustParseAddr("100.64.0.1"), + transit: transitIPs[0].Addr(), + domain: "woo.example.com.", + app: "app1", + }, + lookups: []lookup{ + { + connKey: connectorPeers[0].Key(), + expectedIPs: []netip.Prefix{ + transitIPs[0], + }, + expectedOk: true, + }, + { + connKey: connectorPeers[1].Key(), + expectedIPs: nil, + expectedOk: false, + }, + }, + }, + { + name: "step-2-conn1-tip2", + as: addrs{ + dst: netip.MustParseAddr("1.2.3.2"), + magic: netip.MustParseAddr("100.64.0.2"), + transit: transitIPs[1].Addr(), + domain: "woo.example.com.", + app: "app1", + }, + lookups: []lookup{ + { + connKey: connectorPeers[0].Key(), + expectedIPs: []netip.Prefix{ + transitIPs[0], + transitIPs[1], + }, + expectedOk: true, + }, + }, + }, + { + name: "step-3-conn2-tip1", + as: addrs{ + dst: netip.MustParseAddr("1.2.3.3"), + magic: netip.MustParseAddr("100.64.0.3"), + transit: transitIPs[2].Addr(), + domain: "hoo.example.com.", + app: "app2", + }, + lookups: []lookup{ + { + connKey: connectorPeers[0].Key(), + expectedIPs: []netip.Prefix{ + transitIPs[0], + transitIPs[1], + }, + expectedOk: true, + }, + { + connKey: connectorPeers[1].Key(), + expectedIPs: []netip.Prefix{ + transitIPs[2], + }, + expectedOk: true, + }, + }, + }, + } + + for _, tt := range steps { + t.Run(tt.name, func(t *testing.T) { + // Add and enqueue the addrs, and then wait for the send to complete + // (as indicated by authReconfig being called). + if err := ext.conn25.client.assignments.insert(tt.as); err != nil { + t.Fatalf("error inserting address assignment: %v", err) + } + if err := ext.conn25.client.enqueueAddressAssignment(tt.as); err != nil { + t.Fatalf("error enqueuing address assignment: %v", err) + } + select { + case <-authReconfigAsyncCalled: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for AuthReconfigAsync to be called") + } + + // Check that each of the lookups behaves as expected + for i, lu := range tt.lookups { + got, ok := ext.conn25.client.lookupTransitIPsByConnKey(lu.connKey) + if ok != lu.expectedOk { + t.Fatalf("unexpected ok result at index %d wanted %v, got %v", i, lu.expectedOk, ok) + } + slices.SortFunc(got, func(a, b netip.Prefix) int { return a.Compare(b) }) + if diff := cmp.Diff(lu.expectedIPs, got, cmpopts.EquateComparable(netip.Prefix{})); diff != "" { + t.Fatalf("transit IPs mismatch at index %d, (-want +got):\n%s", i, diff) + } + } + }) + } +} + +func TestTransitIPConnMapping(t *testing.T) { + conn25 := newConn25(t.Logf) + + as := addrs{ + dst: netip.MustParseAddr("1.2.3.1"), + magic: netip.MustParseAddr("100.64.0.1"), + transit: netip.MustParseAddr("169.254.0.1"), + domain: "woo.example.com.", + app: "app1", + } + + connectorPeers := []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: tailcfg.NodeID(0), + Tags: []string{"tag:woo"}, + Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(), + Key: key.NodePublic{}, + }).View(), + (&tailcfg.Node{ + ID: tailcfg.NodeID(2), + Tags: []string{"tag:hoo"}, + Hostinfo: (&tailcfg.Hostinfo{AppConnector: opt.NewBool(true)}).View(), + Key: key.NodePublicFromRaw32(mem.B([]byte{0: 0xff, 31: 0x02})), + }).View(), + } + + // Adding a transit IP that isn't known should fail + if err := conn25.client.addTransitIPForConnector(as.transit, connectorPeers[1]); err == nil { + t.Error("adding an unknown transit IP should fail") + } + + // Insert the address assignments + conn25.client.assignments.insert(as) + + // Adding a transit IP for a node with an unset key should fail + if err := conn25.client.addTransitIPForConnector(as.transit, connectorPeers[0]); err == nil { + t.Error("adding an transit IP mapping for a connector with a zero key should fail") + } + // Adding a transit IP that is known should succeed + if err := conn25.client.addTransitIPForConnector(as.transit, connectorPeers[1]); err != nil { + t.Errorf("unexpected error for first time add: %v", err) + } + // But doing it again should fail + if err := conn25.client.addTransitIPForConnector(as.transit, connectorPeers[1]); err == nil { + t.Error("adding a duplicate transitIP for a connector should fail") + } +} - if first != second { - t.Errorf("expected same addrs on repeated call, got first=%v second=%v", first, second) +func TestClientTransitIPForMagicIP(t *testing.T) { + sn := makeSelfNode(t, []appctype.Conn25Attr{{ + V4MagicIPPool: []netipx.IPRange{v4RangeFrom("0", "10")}, // 100.64.0.0 - 100.64.0.10 + V6MagicIPPool: []netipx.IPRange{v6RangeFrom("0", "10")}, + }}, []string{}) + cfg := mustConfig(t, sn) + + mappedMip := netip.MustParseAddr("100.64.0.0") + mappedTip := netip.MustParseAddr("169.0.0.0") + unmappedMip := netip.MustParseAddr("100.64.0.1") + nonMip := netip.MustParseAddr("100.64.0.11") + dst := netip.MustParseAddr("0.0.0.1") + + v6MappedMip := netip.MustParseAddr("fd7a:115c:a1e0:a99c:0::") + v6MappedTip := netip.MustParseAddr("fd7a:115c:a1e0:a99c:100::") + v6UnmappedMip := netip.MustParseAddr("fd7a:115c:a1e0:a99c:1::") + v6NonMip := netip.MustParseAddr("fd7a:115c:a1e0:a99c:11::") + v6Dst := netip.MustParseAddr("::1") + + for _, tt := range []struct { + name string + mip netip.Addr + wantTip netip.Addr + wantErr error + }{ + { + name: "not-a-magic-ip", + mip: nonMip, + wantTip: netip.Addr{}, + wantErr: nil, + }, + { + name: "unmapped-magic-ip", + mip: unmappedMip, + wantTip: netip.Addr{}, + wantErr: ErrUnmappedMagicIP, + }, + { + name: "mapped-magic-ip", + mip: mappedMip, + wantTip: mappedTip, + wantErr: nil, + }, + { + name: "v6-not-magic", + mip: v6NonMip, + wantTip: netip.Addr{}, + wantErr: nil, + }, + { + name: "v6-unmapped-magic-ip", + mip: v6UnmappedMip, + wantTip: netip.Addr{}, + wantErr: ErrUnmappedMagicIP, + }, + { + name: "v6-mapped-magic-ip", + mip: v6MappedMip, + wantTip: v6MappedTip, + wantErr: nil, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c := newConn25(t.Logf) + c.reconfig(cfg) + + if err := c.client.assignments.insert(addrs{ + magic: mappedMip, + transit: mappedTip, + dst: dst, + }); err != nil { + t.Fatal(err) + } + if err := c.client.assignments.insert(addrs{ + magic: v6MappedMip, + transit: v6MappedTip, + dst: v6Dst, + }); err != nil { + t.Fatal(err) + } + tip, err := c.ClientTransitIPForMagicIP(tt.mip) + if tip != tt.wantTip { + t.Fatalf("checking transit ip: want %v, got %v", tt.wantTip, tip) + } + if err != tt.wantErr { + t.Fatalf("checking error: want %v, got %v", tt.wantErr, err) + } + }) + } +} + +func TestConnectorRealIPForTransitIPConnection(t *testing.T) { + sn := makeSelfNode(t, []appctype.Conn25Attr{{ + V4TransitIPPool: []netipx.IPRange{v4RangeFrom("40", "50")}, // 100.64.0.40 - 100.64.0.50 + }}, []string{}) + cfg := mustConfig(t, sn) + + mappedSrc := netip.MustParseAddr("100.0.0.1") + unmappedSrc := netip.MustParseAddr("100.0.0.2") + mappedTip := netip.MustParseAddr("100.64.0.41") + unmappedTip := netip.MustParseAddr("100.64.0.42") + nonTip := netip.MustParseAddr("100.0.0.3") + mappedMip := netip.MustParseAddr("100.64.0.1") + for _, tt := range []struct { + name string + src netip.Addr + tip netip.Addr + wantMip netip.Addr + wantErr error + }{ + { + name: "not-a-transit-ip-unmapped-src", + src: unmappedSrc, + tip: nonTip, + wantMip: netip.Addr{}, + wantErr: nil, + }, + { + name: "not-a-transit-ip-mapped-src", + src: mappedSrc, + tip: nonTip, + wantMip: netip.Addr{}, + wantErr: nil, + }, + { + name: "unmapped-src-transit-ip", + src: unmappedSrc, + tip: unmappedTip, + wantMip: netip.Addr{}, + wantErr: ErrUnmappedSrcAndTransitIP, + }, + { + name: "unmapped-tip-transit-ip", + src: mappedSrc, + tip: unmappedTip, + wantMip: netip.Addr{}, + wantErr: ErrUnmappedSrcAndTransitIP, + }, + { + name: "mapped-src-and-transit-ip", + src: mappedSrc, + tip: mappedTip, + wantMip: mappedMip, + wantErr: nil, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c := newConn25(t.Logf) + c.reconfig(cfg) + c.connector.transitIPs = map[netip.Addr]map[netip.Addr]appAddr{} + c.connector.transitIPs[mappedSrc] = map[netip.Addr]appAddr{} + c.connector.transitIPs[mappedSrc][mappedTip] = appAddr{addr: mappedMip} + mip, err := c.ConnectorRealIPForTransitIPConnection(tt.src, tt.tip) + if mip != tt.wantMip { + t.Fatalf("checking magic ip: want %v, got %v", tt.wantMip, mip) + } + if err != tt.wantErr { + t.Fatalf("checking error: want %v, got %v", tt.wantErr, err) + } + }) + } +} + +func TestIsKnownTransitIP(t *testing.T) { + knownTip := netip.MustParseAddr("100.64.0.41") + unknownTip := netip.MustParseAddr("100.64.0.42") + + c := newConn25(t.Logf) + c.client.assignments.insert(addrs{ + transit: knownTip, + }) + + if !c.client.isKnownTransitIP(knownTip) { + t.Fatal("knownTip: should have been known") + } + if c.client.isKnownTransitIP(unknownTip) { + t.Fatal("unknownTip: should not have been known") + } +} + +func TestLinkLocalAllow(t *testing.T) { + knownTip := netip.MustParseAddr("100.64.0.41") + + c := newConn25(t.Logf) + c.client.assignments.insert(addrs{ + transit: knownTip, + }) + + if allow, _ := c.client.linkLocalAllow(packet.Parsed{ + Dst: netip.AddrPortFrom(knownTip, 1234), + }); !allow { + t.Fatal("knownTip: should have been allowed") + } + + if allow, _ := c.client.linkLocalAllow(packet.Parsed{ + Dst: netip.AddrPort{}, + }); allow { + t.Fatal("unknownTip: should not have been allowed") + } +} + +func TestConnectorPacketFilterAllow(t *testing.T) { + knownTip := netip.MustParseAddr("100.64.0.41") + knownSrc := netip.MustParseAddr("100.64.0.1") + unknownTip := netip.MustParseAddr("100.64.0.42") + unknownSrc := netip.MustParseAddr("100.64.0.42") + + c := newConn25(t.Logf) + c.connector.transitIPs = map[netip.Addr]map[netip.Addr]appAddr{} + c.connector.transitIPs[knownSrc] = map[netip.Addr]appAddr{} + c.connector.transitIPs[knownSrc][knownTip] = appAddr{} + + if allow, _ := c.connector.packetFilterAllow(packet.Parsed{ + Src: netip.AddrPortFrom(knownSrc, 1234), + Dst: netip.AddrPortFrom(knownTip, 1234), + }); !allow { + t.Fatal("knownTip: should have been allowed") + } + + if allow, _ := c.connector.packetFilterAllow(packet.Parsed{ + Src: netip.AddrPortFrom(unknownSrc, 1234), + Dst: netip.AddrPortFrom(knownTip, 1234), + }); allow { + t.Fatal("unknownSrc: should not have been allowed") + } + if allow, _ := c.connector.packetFilterAllow(packet.Parsed{ + Src: netip.AddrPortFrom(knownSrc, 1234), + Dst: netip.AddrPortFrom(unknownTip, 1234), + }); allow { + t.Fatal("unknownTip: should not have been allowed") + } +} + +func TestGetMagicRange(t *testing.T) { + sn := makeSelfNode(t, []appctype.Conn25Attr{{ + Name: "app1", + Connectors: []string{"tag:woo"}, + Domains: []string{"example.com"}, + V4MagicIPPool: []netipx.IPRange{netipx.IPRangeFrom(netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.3"))}, + V6MagicIPPool: []netipx.IPRange{netipx.IPRangeFrom(netip.MustParseAddr("::1"), netip.MustParseAddr("::3"))}, + }}, []string{}) + cfg := mustConfig(t, sn) + c := newConn25(t.Logf) + c.reconfig(cfg) + ext := &extension{ + conn25: c, + } + mRange := ext.getMagicRange() + somePrefixCovers := func(a netip.Addr) bool { + for _, r := range mRange.All() { + if r.Contains(a) { + return true + } + } + return false + } + ins := []string{ + "0.0.0.1", + "0.0.0.2", + "0.0.0.3", + "::1", + "::2", + "::3", + } + outs := []string{ + "0.0.0.0", + "0.0.0.4", + "::", + "::4", + } + for _, s := range ins { + if !somePrefixCovers(netip.MustParseAddr(s)) { + t.Fatalf("expected addr to be covered but was not: %s", s) + } + } + for _, s := range outs { + if somePrefixCovers(netip.MustParseAddr(s)) { + t.Fatalf("expected addr to NOT be covered but WAS: %s", s) + } + } +} + +func TestAssignmentsExpire(t *testing.T) { + clock := tstest.NewClock(tstest.ClockOpts{Start: time.Now()}) + assignments := addrAssignments{clock: clock} + as := addrs{ + dst: netip.MustParseAddr("0.0.0.1"), + magic: netip.MustParseAddr("0.0.0.2"), + transit: netip.MustParseAddr("0.0.0.3"), + app: "a", + domain: "example.com.", + } + err := assignments.insert(as) + if err != nil { + t.Fatal(err) + } + // Time has not passed since the insert, the assignment should be returned. + foundAs, ok := assignments.lookupByMagicIP(as.magic) + if !ok { + t.Fatal("expected to find") + } + if foundAs.dst != as.dst { + t.Fatalf("want %v; got %v", as.dst, foundAs.dst) + } + // and we cannot insert over the addresses + err = assignments.insert(as) + if err == nil { + t.Fatal("expected an error but got nil") + } + // After a time greater than the default expiry passes, the assignment should + // not be returned. + clock.Advance(defaultExpiry * 2) + foundAsAfter, okAfter := assignments.lookupByMagicIP(as.magic) + if okAfter { + t.Fatal("expected not to find (expired)") + } + if foundAsAfter.isValid() { + t.Fatal("expected zero val") + } + // Now we can reuse the addresses + err = assignments.insert(as) + if err != nil { + t.Fatal(err) + } + foundAs, ok = assignments.lookupByMagicIP(as.magic) + if !ok { + t.Fatal("expected to find") } - if got := len(c.client.assignments.byMagicIP); got != 1 { - t.Errorf("want 1 entry in byMagicIP, got %d", got) + if foundAs.dst != as.dst { + t.Fatalf("want %v; got %v", as.dst, foundAs.dst) } - if got := len(c.client.assignments.byDomainDst); got != 1 { - t.Errorf("want 1 entry in byDomainDst, got %d", got) + if !foundAs.expiresAt.After(clock.Now()) { + t.Fatalf("expected foundAs to expire after now") } } diff --git a/feature/conn25/datapath.go b/feature/conn25/datapath.go new file mode 100644 index 0000000000000..b5cdd51550042 --- /dev/null +++ b/feature/conn25/datapath.go @@ -0,0 +1,242 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "errors" + "net/netip" + + "tailscale.com/envknob" + "tailscale.com/net/flowtrack" + "tailscale.com/net/packet" + "tailscale.com/net/packet/checksum" + "tailscale.com/types/ipproto" + "tailscale.com/types/logger" + "tailscale.com/wgengine/filter" +) + +var ( + ErrUnmappedMagicIP = errors.New("unmapped magic IP") + ErrUnmappedSrcAndTransitIP = errors.New("unmapped src and transit IP") +) + +// IPMapper provides methods for mapping special app connector IPs to each other +// in aid of performing DNAT and SNAT on app connector packets. +type IPMapper interface { + // ClientTransitIPForMagicIP returns a Transit IP for the given magicIP on a client. + // If the magicIP is within a configured Magic IP range for an app on the client, + // but not mapped to an active Transit IP, implementations should return [ErrUnmappedMagicIP]. + // If magicIP is not within a configured Magic IP range, i.e. it is not actually a Magic IP, + // implementations should return a nil error, and a zero-value [netip.Addr] to indicate + // this potentially valid, non-app-connector traffic. + ClientTransitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) + + // ConnectorRealIPForTransitIPConnection returns a real destination IP for the given + // srcIP and transitIP on a connector. If the transitIP is within a configured Transit IP + // range for an app on the connector, but not mapped to the client at srcIP, implementations + // should return [ErrUnmappedSrcAndTransitIP]. If the transitIP is not within a configured + // Transit IP range, i.e. it is not actually a Transit IP, implementations should return + // a nil error, and a zero-value [netip.Addr] to indicate this is potentially valid, + // non-app-connector traffic. + ConnectorRealIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) +} + +// datapathHandler handles packets from the datapath, +// performing appropriate NAT operations to support Connectors 2025. +// It maintains [FlowTable] caches for fast lookups of established flows. +// +// When hooked into the main datapath filter chain in [tstun], the datapathHandler +// will see every packet on the node, regardless of whether it is relevant to +// app connector operations. In the common case of non-connector traffic, it +// passes the packet through unmodified. +// +// It classifies each packet based on the presence of special Magic IPs or +// Transit IPs, and determines whether the packet is flowing through a "client" +// (the node with the application that starts the connection), or a "connector" +// (the node that connects to the internet-hosted destination). On the client, +// outbound connections are DNATed from Magic IP to Transit IP, and return +// traffic is SNATed from Transit IP to Magic IP. On the connector, outbound +// connections are DNATed from Transit IP to real IP, and return traffic is +// SNATed from real IP to Transit IP. +// +// There are two exposed methods, one for handling packets from the tun device, +// and one for handling packets from WireGuard, but through the use of flow tables, +// we can handle four cases: client outbound, client return, connector outbound, +// connector return. The first packet goes through IPMapper, which is where Connectors +// 2025 authoritative state is stored. For valid packets relevant to connectors, +// a bidirectional flow entry is installed, so that subsequent packets (and all return traffic) +// hit that cache. Only outbound (towards internet) packets create new flows; return (from internet) +// packets either match a cached entry or pass through. +// +// We check the cache before IPMapper both for performance, and so that existing flows stay alive +// even if address mappings change mid-flow. +type datapathHandler struct { + ipMapper IPMapper + + // Flow caches. One for the client, and one for the connector. + clientFlowTable *FlowTable + connectorFlowTable *FlowTable + + logf logger.Logf + debugLogging bool +} + +func newDatapathHandler(ipMapper IPMapper, logf logger.Logf) *datapathHandler { + return &datapathHandler{ + ipMapper: ipMapper, + + // TODO(mzb): Figure out sensible default max size for flow tables. + // Don't do any LRU eviction until we figure out deletion and expiration. + clientFlowTable: NewFlowTable(0), + connectorFlowTable: NewFlowTable(0), + logf: logf, + debugLogging: envknob.Bool("TS_CONN25_DATAPATH_DEBUG"), + } +} + +// HandlePacketFromWireGuard inspects packets coming from WireGuard, and performs +// appropriate DNAT or SNAT actions for Connectors 2025. Returning [filter.Accept] signals +// that the packet should pass through subsequent stages of the datapath pipeline. +// Returning [filter.Drop] signals the packet should be dropped. This method handles all +// packets coming from WireGuard, on both connectors, and clients of connectors. +func (dh *datapathHandler) HandlePacketFromWireGuard(p *packet.Parsed) filter.Response { + // TODO(tailscale/corp#38764): Support other protocols, like ICMP for error messages. + if p.IPProto != ipproto.TCP && p.IPProto != ipproto.UDP { + return filter.Accept + } + + // Check if this is an existing (return) flow on a client. + // If found, perform the action for the existing client flow and return. + existing, ok := dh.clientFlowTable.LookupFromWireGuard(flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst)) + if ok { + existing.Action(p) + return filter.Accept + } + + // Check if this is an existing connector outbound flow. + // If found, perform the action for the existing connector outbound flow and return. + existing, ok = dh.connectorFlowTable.LookupFromWireGuard(flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst)) + if ok { + existing.Action(p) + return filter.Accept + } + + // The flow was not found in either flow table. Since the packet came in + // from WireGuard, it can only be a new flow on the connector, + // other (non-app-connector) traffic, or broken app-connector traffic + // that needs to be re-established by a new outbound packet. + transitIP := p.Dst.Addr() + realIP, err := dh.ipMapper.ConnectorRealIPForTransitIPConnection(p.Src.Addr(), transitIP) + if err != nil { + if errors.Is(err, ErrUnmappedSrcAndTransitIP) { + // TODO(tailscale/corp#34256): This path should deliver an ICMP error to the client. + return filter.Drop + } + dh.debugLogf("error mapping src and transit IP, passing packet unmodified: %v", err) + return filter.Accept + } + + // If this is normal non-app-connector traffic, forward it along unmodified. + if !realIP.IsValid() { + return filter.Accept + } + + // This is a new outbound flow on a connector. Install a DNAT TransitIP-to-RealIP action + // for the outgoing direction, and an SNAT RealIP-to-TransitIP action for the + // return direction. + outgoing := FlowData{ + Tuple: flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst), + Action: dh.dnatAction(realIP), + } + incoming := FlowData{ + Tuple: flowtrack.MakeTuple(p.IPProto, netip.AddrPortFrom(realIP, p.Dst.Port()), p.Src), + Action: dh.snatAction(transitIP), + } + if err := dh.connectorFlowTable.NewFlowFromWireGuard(outgoing, incoming); err != nil { + dh.debugLogf("error installing flow, passing packet unmodified: %v", err) + return filter.Accept + } + outgoing.Action(p) + return filter.Accept +} + +// HandlePacketFromTunDevice inspects packets coming from the tun device, and performs +// appropriate DNAT or SNAT actions for Connectors 2025. Returning [filter.Accept] signals +// that the packet should pass through subsequent stages of the datapath pipeline. +// Returning [filter.Drop] signals the packet should be dropped. This method handles all +// packets coming from the tun device, on both connectors, and clients of connectors. +func (dh *datapathHandler) HandlePacketFromTunDevice(p *packet.Parsed) filter.Response { + // TODO(tailscale/corp#38764): Support other protocols, like ICMP for error messages. + if p.IPProto != ipproto.TCP && p.IPProto != ipproto.UDP { + return filter.Accept + } + + // Check if this is an existing client outbound flow. + // If found, perform the action for the existing client flow and return. + existing, ok := dh.clientFlowTable.LookupFromTunDevice(flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst)) + if ok { + existing.Action(p) + return filter.Accept + } + + // Check if this is an existing connector return flow. + // If found, perform the action for the existing connector return flow and return. + existing, ok = dh.connectorFlowTable.LookupFromTunDevice(flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst)) + if ok { + existing.Action(p) + return filter.Accept + } + + // The flow was not found in either flow table. Since the packet came in on the + // tun device, it can only be a new client flow, other (non-app-connector) traffic, + // or broken return app-connector traffic on a connector, which needs to be re-established + // with a new outbound packet. + magicIP := p.Dst.Addr() + transitIP, err := dh.ipMapper.ClientTransitIPForMagicIP(magicIP) + if err != nil { + if errors.Is(err, ErrUnmappedMagicIP) { + // TODO(tailscale/corp#34257): This path should deliver an ICMP error to the client. + return filter.Drop + } + dh.debugLogf("error mapping magic IP, passing packet unmodified: %v", err) + return filter.Accept + } + + // If this is normal non-app-connector traffic, forward it along unmodified. + if !transitIP.IsValid() { + return filter.Accept + } + + // This is a new outbound client flow. Install a DNAT MagicIP-to-TransitIP action + // for the outgoing direction, and an SNAT TransitIP-to-MagicIP action for the + // return direction. + outgoing := FlowData{ + Tuple: flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst), + Action: dh.dnatAction(transitIP), + } + incoming := FlowData{ + Tuple: flowtrack.MakeTuple(p.IPProto, netip.AddrPortFrom(transitIP, p.Dst.Port()), p.Src), + Action: dh.snatAction(magicIP), + } + if err := dh.clientFlowTable.NewFlowFromTunDevice(outgoing, incoming); err != nil { + dh.debugLogf("error installing flow from tun device, passing packet unmodified: %v", err) + return filter.Accept + } + outgoing.Action(p) + return filter.Accept +} + +func (dh *datapathHandler) dnatAction(to netip.Addr) PacketAction { + return PacketAction(func(p *packet.Parsed) { checksum.UpdateDstAddr(p, to) }) +} + +func (dh *datapathHandler) snatAction(to netip.Addr) PacketAction { + return PacketAction(func(p *packet.Parsed) { checksum.UpdateSrcAddr(p, to) }) +} + +func (dh *datapathHandler) debugLogf(msg string, args ...any) { + if dh.debugLogging { + dh.logf(msg, args...) + } +} diff --git a/feature/conn25/datapath_test.go b/feature/conn25/datapath_test.go new file mode 100644 index 0000000000000..f75b89d29fed0 --- /dev/null +++ b/feature/conn25/datapath_test.go @@ -0,0 +1,361 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "errors" + "net/netip" + "testing" + + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" + "tailscale.com/wgengine/filter" +) + +type testConn25 struct { + clientTransitIPForMagicIPFn func(netip.Addr) (netip.Addr, error) + connectorRealIPForTransitIPConnectionFn func(netip.Addr, netip.Addr) (netip.Addr, error) +} + +func (tc *testConn25) ClientTransitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) { + return tc.clientTransitIPForMagicIPFn(magicIP) +} + +func (tc *testConn25) ConnectorRealIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) { + return tc.connectorRealIPForTransitIPConnectionFn(srcIP, transitIP) +} + +func TestHandlePacketFromTunDevice(t *testing.T) { + clientSrcIP := netip.MustParseAddr("100.70.0.1") + magicIP := netip.MustParseAddr("10.64.0.1") + unusedMagicIP := netip.MustParseAddr("10.64.0.2") + transitIP := netip.MustParseAddr("169.254.0.1") + realIP := netip.MustParseAddr("240.64.0.1") + + clientPort := uint16(1234) + serverPort := uint16(80) + + tests := []struct { + description string + p *packet.Parsed + throwMappingErr bool + expectedSrc netip.AddrPort + expectedDst netip.AddrPort + expectedFilterResponse filter.Response + }{ + { + description: "accept-and-nat-new-client-flow-mapped-magic-ip", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(magicIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(transitIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "drop-unmapped-magic-ip", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(unusedMagicIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(unusedMagicIP, serverPort), + expectedFilterResponse: filter.Drop, + }, + { + description: "accept-dont-nat-other-mapping-error", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(magicIP, serverPort), + }, + throwMappingErr: true, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(magicIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "accept-dont-nat-uninteresting-client-side", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(realIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(realIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "accept-dont-nat-uninteresting-connector-side", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(realIP, serverPort), + Dst: netip.AddrPortFrom(clientSrcIP, clientPort), + }, + expectedSrc: netip.AddrPortFrom(realIP, serverPort), + expectedDst: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedFilterResponse: filter.Accept, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + mock := &testConn25{} + mock.clientTransitIPForMagicIPFn = func(mip netip.Addr) (netip.Addr, error) { + if tt.throwMappingErr { + return netip.Addr{}, errors.New("synthetic mapping error") + } + if mip == magicIP { + return transitIP, nil + } + if mip == unusedMagicIP { + return netip.Addr{}, ErrUnmappedMagicIP + } + return netip.Addr{}, nil + } + dph := newDatapathHandler(mock, t.Logf) + + tt.p.IPProto = ipproto.UDP + tt.p.IPVersion = 4 + tt.p.StuffForTesting(40) + + if want, got := tt.expectedFilterResponse, dph.HandlePacketFromTunDevice(tt.p); want != got { + t.Errorf("unexpected filter response: want %v, got %v", want, got) + } + if want, got := tt.expectedSrc, tt.p.Src; want != got { + t.Errorf("unexpected packet src: want %v, got %v", want, got) + } + if want, got := tt.expectedDst, tt.p.Dst; want != got { + t.Errorf("unexpected packet dst: want %v, got %v", want, got) + } + }) + } +} + +func TestHandlePacketFromWireGuard(t *testing.T) { + clientSrcIP := netip.MustParseAddr("100.70.0.1") + unknownSrcIP := netip.MustParseAddr("100.99.99.99") + transitIP := netip.MustParseAddr("169.254.0.1") + realIP := netip.MustParseAddr("240.64.0.1") + + clientPort := uint16(1234) + serverPort := uint16(80) + + tests := []struct { + description string + p *packet.Parsed + throwMappingErr bool + expectedSrc netip.AddrPort + expectedDst netip.AddrPort + expectedFilterResponse filter.Response + }{ + { + description: "accept-and-nat-new-connector-flow-mapped-src-and-transit-ip", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(transitIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(realIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "drop-unmapped-src-and-transit-ip", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(unknownSrcIP, clientPort), + Dst: netip.AddrPortFrom(transitIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(unknownSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(transitIP, serverPort), + expectedFilterResponse: filter.Drop, + }, + { + description: "accept-dont-nat-other-mapping-error", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(transitIP, serverPort), + }, + throwMappingErr: true, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(transitIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "accept-dont-nat-uninteresting-connector-side", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(realIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(realIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "accept-dont-nat-uninteresting-client-side", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(realIP, serverPort), + Dst: netip.AddrPortFrom(clientSrcIP, clientPort), + }, + expectedSrc: netip.AddrPortFrom(realIP, serverPort), + expectedDst: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedFilterResponse: filter.Accept, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + mock := &testConn25{} + mock.connectorRealIPForTransitIPConnectionFn = func(src, tip netip.Addr) (netip.Addr, error) { + if tt.throwMappingErr { + return netip.Addr{}, errors.New("synthetic mapping error") + } + if tip == transitIP { + if src == clientSrcIP { + return realIP, nil + } else { + return netip.Addr{}, ErrUnmappedSrcAndTransitIP + } + } + return netip.Addr{}, nil + } + dph := newDatapathHandler(mock, t.Logf) + + tt.p.IPProto = ipproto.UDP + tt.p.IPVersion = 4 + tt.p.StuffForTesting(40) + + if want, got := tt.expectedFilterResponse, dph.HandlePacketFromWireGuard(tt.p); want != got { + t.Errorf("unexpected filter response: want %v, got %v", want, got) + } + if want, got := tt.expectedSrc, tt.p.Src; want != got { + t.Errorf("unexpected packet src: want %v, got %v", want, got) + } + if want, got := tt.expectedDst, tt.p.Dst; want != got { + t.Errorf("unexpected packet dst: want %v, got %v", want, got) + } + }) + } +} + +func TestClientFlowCache(t *testing.T) { + getTransitIPCalled := false + + clientSrcIP := netip.MustParseAddr("100.70.0.1") + magicIP := netip.MustParseAddr("10.64.0.1") + transitIP := netip.MustParseAddr("169.254.0.1") + + clientPort := uint16(1234) + serverPort := uint16(80) + + mock := &testConn25{} + mock.clientTransitIPForMagicIPFn = func(mip netip.Addr) (netip.Addr, error) { + if getTransitIPCalled { + t.Errorf("ClientGetTransitIPForMagicIP unexpectedly called more than once") + } + getTransitIPCalled = true + return transitIP, nil + } + dph := newDatapathHandler(mock, t.Logf) + + outgoing := packet.Parsed{ + IPProto: ipproto.UDP, + IPVersion: 4, + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(magicIP, serverPort), + } + outgoing.StuffForTesting(40) + + o1 := outgoing + if dph.HandlePacketFromTunDevice(&o1) != filter.Accept { + t.Errorf("first call to HandlePacketFromTunDevice was not accepted") + } + if want, got := netip.AddrPortFrom(transitIP, serverPort), o1.Dst; want != got { + t.Errorf("unexpected packet dst after first call: want %v, got %v", want, got) + } + // The second call should use the cache. + o2 := outgoing + if dph.HandlePacketFromTunDevice(&o2) != filter.Accept { + t.Errorf("second call to HandlePacketFromTunDevice was not accepted") + } + if want, got := netip.AddrPortFrom(transitIP, serverPort), o2.Dst; want != got { + t.Errorf("unexpected packet dst after second call: want %v, got %v", want, got) + } + + // Return traffic should have the Transit IP as the source, + // and be SNATed to the Magic IP. + incoming := &packet.Parsed{ + IPProto: ipproto.UDP, + IPVersion: 4, + Src: netip.AddrPortFrom(transitIP, serverPort), + Dst: netip.AddrPortFrom(clientSrcIP, clientPort), + } + incoming.StuffForTesting(40) + + if dph.HandlePacketFromWireGuard(incoming) != filter.Accept { + t.Errorf("call to HandlePacketFromWireGuard was not accepted") + } + if want, got := netip.AddrPortFrom(magicIP, serverPort), incoming.Src; want != got { + t.Errorf("unexpected packet src after second call: want %v, got %v", want, got) + } +} + +func TestConnectorFlowCache(t *testing.T) { + getRealIPCalled := false + + clientSrcIP := netip.MustParseAddr("100.70.0.1") + transitIP := netip.MustParseAddr("169.254.0.1") + realIP := netip.MustParseAddr("240.64.0.1") + + clientPort := uint16(1234) + serverPort := uint16(80) + + mock := &testConn25{} + mock.connectorRealIPForTransitIPConnectionFn = func(src, tip netip.Addr) (netip.Addr, error) { + if getRealIPCalled { + t.Errorf("ConnectorRealIPForTransitIPConnection unexpectedly called more than once") + } + getRealIPCalled = true + return realIP, nil + } + dph := newDatapathHandler(mock, t.Logf) + + outgoing := packet.Parsed{ + IPProto: ipproto.UDP, + IPVersion: 4, + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(transitIP, serverPort), + } + outgoing.StuffForTesting(40) + + o1 := outgoing + if dph.HandlePacketFromWireGuard(&o1) != filter.Accept { + t.Errorf("first call to HandlePacketFromWireGuard was not accepted") + } + if want, got := netip.AddrPortFrom(realIP, serverPort), o1.Dst; want != got { + t.Errorf("unexpected packet dst after first call: want %v, got %v", want, got) + } + // The second call should use the cache. + o2 := outgoing + if dph.HandlePacketFromWireGuard(&o2) != filter.Accept { + t.Errorf("second call to HandlePacketFromWireGuard was not accepted") + } + if want, got := netip.AddrPortFrom(realIP, serverPort), o2.Dst; want != got { + t.Errorf("unexpected packet dst after second call: want %v, got %v", want, got) + } + + // Return traffic should have the Real IP as the source, + // and be SNATed to the Transit IP. + incoming := &packet.Parsed{ + IPProto: ipproto.UDP, + IPVersion: 4, + Src: netip.AddrPortFrom(realIP, serverPort), + Dst: netip.AddrPortFrom(clientSrcIP, clientPort), + } + incoming.StuffForTesting(40) + + if dph.HandlePacketFromTunDevice(incoming) != filter.Accept { + t.Errorf("call to HandlePacketFromTunDevice was not accepted") + } + if want, got := netip.AddrPortFrom(transitIP, serverPort), incoming.Src; want != got { + t.Errorf("unexpected packet src after second call: want %v, got %v", want, got) + } +} diff --git a/feature/conn25/flowtable.go b/feature/conn25/flowtable.go new file mode 100644 index 0000000000000..27486ded910a2 --- /dev/null +++ b/feature/conn25/flowtable.go @@ -0,0 +1,149 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "errors" + "sync" + + "tailscale.com/net/flowtrack" + "tailscale.com/net/packet" +) + +// PacketAction may modify the packet. +type PacketAction func(*packet.Parsed) + +// FlowData is an entry stored in the [FlowTable]. +type FlowData struct { + Tuple flowtrack.Tuple + Action PacketAction +} + +// Origin is used to track the direction of a flow. +type Origin uint8 + +const ( + // FromTun indicates the flow is from the tun device. + FromTun Origin = iota + + // FromWireGuard indicates the flow is from the WireGuard tunnel. + FromWireGuard +) + +type cachedFlow struct { + flow FlowData + paired flowtrack.Tuple // tuple for the other direction +} + +// FlowTable stores and retrieves [FlowData] that can be looked up +// by 5-tuple. New entries specify the tuple to use for both directions +// of traffic flow. The underlying cache is LRU, and the maximum number +// of entries is specified in calls to [NewFlowTable]. FlowTable has +// its own mutex and is safe for concurrent use. +type FlowTable struct { + mu sync.Mutex + fromTunCache *flowtrack.Cache[cachedFlow] // guarded by mu + fromWGCache *flowtrack.Cache[cachedFlow] // guarded by mu +} + +// NewFlowTable returns a [FlowTable] maxEntries maximum entries. +// A maxEntries of 0 indicates no maximum. See also [FlowTable]. +func NewFlowTable(maxEntries int) *FlowTable { + return &FlowTable{ + fromTunCache: &flowtrack.Cache[cachedFlow]{ + MaxEntries: maxEntries, + }, + fromWGCache: &flowtrack.Cache[cachedFlow]{ + MaxEntries: maxEntries, + }, + } +} + +// LookupFromTunDevice looks up a [FlowData] entry that is valid to run for packets +// observed as coming from the tun device. The tuple must match the direction it was +// stored with. +func (t *FlowTable) LookupFromTunDevice(k flowtrack.Tuple) (FlowData, bool) { + return t.lookup(k, FromTun) +} + +// LookupFromWireGuard looks up a [FlowData] entry that is valid to run for packets +// observed as coming from the WireGuard tunnel. The tuple must match the direction it was +// stored with. +func (t *FlowTable) LookupFromWireGuard(k flowtrack.Tuple) (FlowData, bool) { + return t.lookup(k, FromWireGuard) +} + +func (t *FlowTable) lookup(k flowtrack.Tuple, want Origin) (FlowData, bool) { + var cache *flowtrack.Cache[cachedFlow] + switch want { + case FromTun: + cache = t.fromTunCache + case FromWireGuard: + cache = t.fromWGCache + default: + return FlowData{}, false + } + + t.mu.Lock() + defer t.mu.Unlock() + + v, ok := cache.Get(k) + if !ok { + return FlowData{}, false + } + return v.flow, true +} + +// NewFlowFromTunDevice installs (or overwrites) both the forward and return entries. +// The forward tuple is tagged as FromTun, and the return tuple is tagged as FromWireGuard. +// If overwriting, it removes the old paired tuple for the forward key to avoid stale reverse mappings. +func (t *FlowTable) NewFlowFromTunDevice(fwd, rev FlowData) error { + return t.newFlow(FromTun, fwd, rev) +} + +// NewFlowFromWireGuard installs (or overwrites) both the forward and return entries. +// The forward tuple is tagged as FromWireGuard, and the return tuple is tagged as FromTun. +// If overwriting, it removes the old paired tuple for the forward key to avoid stale reverse mappings. +func (t *FlowTable) NewFlowFromWireGuard(fwd, rev FlowData) error { + return t.newFlow(FromWireGuard, fwd, rev) +} + +func (t *FlowTable) newFlow(fwdOrigin Origin, fwd, rev FlowData) error { + if fwd.Action == nil || rev.Action == nil { + return errors.New("nil action received for flow") + } + + var fwdCache, revCache *flowtrack.Cache[cachedFlow] + switch fwdOrigin { + case FromTun: + fwdCache, revCache = t.fromTunCache, t.fromWGCache + case FromWireGuard: + fwdCache, revCache = t.fromWGCache, t.fromTunCache + default: + return errors.New("newFlow called with unknown direction") + } + + t.mu.Lock() + defer t.mu.Unlock() + + // If overwriting an existing entry, remove its previously-paired mapping so + // we don't leave stale tuples around. + if old, ok := fwdCache.Get(fwd.Tuple); ok { + revCache.Remove(old.paired) + } + if old, ok := revCache.Get(rev.Tuple); ok { + fwdCache.Remove(old.paired) + } + + fwdCache.Add(fwd.Tuple, cachedFlow{ + flow: fwd, + paired: rev.Tuple, + }) + revCache.Add(rev.Tuple, cachedFlow{ + flow: rev, + paired: fwd.Tuple, + }) + + return nil +} diff --git a/feature/conn25/flowtable_test.go b/feature/conn25/flowtable_test.go new file mode 100644 index 0000000000000..8c3cd63a20e77 --- /dev/null +++ b/feature/conn25/flowtable_test.go @@ -0,0 +1,125 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "net/netip" + "testing" + + "tailscale.com/net/flowtrack" + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" +) + +func TestFlowTable(t *testing.T) { + ft := NewFlowTable(0) + + fwdTuple := flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("1.2.3.4:1000"), + netip.MustParseAddrPort("4.3.2.1:80"), + ) + // Reverse tuple is defined by caller. Doesn't have to be mirror image of fwd. + // To account for intentional modifications, like NAT. + revTuple := flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("4.3.2.2:80"), + netip.MustParseAddrPort("1.2.3.4:1000"), + ) + + fwdAction, revAction := 0, 0 + fwdData := FlowData{ + Tuple: fwdTuple, + Action: func(_ *packet.Parsed) { fwdAction++ }, + } + revData := FlowData{ + Tuple: revTuple, + Action: func(_ *packet.Parsed) { revAction++ }, + } + + // For this test setup, from the tun device will be "forward", + // and from WG will be "reverse". + if err := ft.NewFlowFromTunDevice(fwdData, revData); err != nil { + t.Fatalf("got non-nil error for new flow from tun device") + } + + // Test basic lookups. + lookupFwd, ok := ft.LookupFromTunDevice(fwdTuple) + if !ok { + t.Fatalf("got not found on first lookup from tun device") + } + lookupFwd.Action(nil) + if fwdAction != 1 { + t.Errorf("action for fwd tuple key was not executed") + } + + lookupRev, ok := ft.LookupFromWireGuard(revTuple) + if !ok { + t.Fatalf("got not found on first lookup from WireGuard") + } + lookupRev.Action(nil) + if revAction != 1 { + t.Errorf("action for rev tuple key was not executed") + } + + // Test not found error. + notFoundTuple := flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("1.2.3.4:1000"), + netip.MustParseAddrPort("4.0.4.4:80"), + ) + if _, ok := ft.LookupFromTunDevice(notFoundTuple); ok { + t.Errorf("expected not found for foreign tuple") + } + + // Wrong direction is also not found. + if _, ok := ft.LookupFromWireGuard(fwdTuple); ok { + t.Errorf("expected not found for wrong direction tuple") + } + + // Overwriting forward tuple removes its reverse pair as well. + newRevData := FlowData{ + Tuple: flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("9.9.9.9:99"), + netip.MustParseAddrPort("8.8.8.8:88"), + ), + Action: func(_ *packet.Parsed) {}, + } + if err := ft.NewFlowFromTunDevice( + fwdData, + newRevData, + ); err != nil { + t.Fatalf("got non-nil error for new flow from tun device") + } + if _, ok := ft.LookupFromWireGuard(revTuple); ok { + t.Errorf("expected not found for removed reverse tuple") + } + + // Overwriting reverse tuple removes its forward pair as well. + if err := ft.NewFlowFromTunDevice( + FlowData{ + Tuple: flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("8.8.8.8:88"), + netip.MustParseAddrPort("9.9.9.9:99"), + ), + Action: func(_ *packet.Parsed) {}, + }, + newRevData, // This is the same "reverse" data installed in previous test. + ); err != nil { + t.Fatalf("got non-nil error for new flow from tun device") + } + if _, ok := ft.LookupFromTunDevice(fwdTuple); ok { + t.Errorf("expected not found for removed forward tuple") + } + + // Nil action returns an error. + if err := ft.NewFlowFromTunDevice( + FlowData{}, + FlowData{}, + ); err == nil { + t.Errorf("expected non-nil error for nil data") + } +} diff --git a/feature/conn25/ippool.go b/feature/conn25/ippool.go index e50186d880914..4ae8918d49397 100644 --- a/feature/conn25/ippool.go +++ b/feature/conn25/ippool.go @@ -8,17 +8,24 @@ import ( "net/netip" "go4.org/netipx" + "tailscale.com/util/set" ) // errPoolExhausted is returned when there are no more addresses to iterate over. var errPoolExhausted = errors.New("ip pool exhausted") -// ippool allows for iteration over all the addresses within a netipx.IPSet. +// errNotOurAddress is returned if a provided address is not from our pool +var errNotOurAddress = errors.New("not our address") + +// errAddrExists is returned if a returned address is already in the returned pool. +var errAddrExists = errors.New("address already returned") + +// ipSetIterator allows for round robin iteration over all the addresses within a netipx.IPSet. // netipx.IPSet has a Ranges call that returns the "minimum and sorted set of IP ranges that covers [the set]". // netipx.IPRange is "an inclusive range of IP addresses from the same address family.". So we can iterate over // all the addresses in the set by keeping a track of the last address we returned, calling Next on the last address -// to get the new one, and if we run off the edge of the current range, starting on the next one. -type ippool struct { +// to get the new one, and if we run off the edge of the current range, starting on the next one, or back at the beginning. +type ipSetIterator struct { // ranges defines the addresses in the pool ranges []netipx.IPRange // last is internal tracking of which the last address provided was. @@ -27,35 +34,75 @@ type ippool struct { rangeIdx int } +// next returns the next address from the set. +func (ipsi *ipSetIterator) next() (netip.Addr, error) { + if len(ipsi.ranges) == 0 { + // ipset is empty + return netip.Addr{}, errPoolExhausted + } + if !ipsi.last.IsValid() { + // not initialized yet + ipsi.last = ipsi.ranges[0].From() + return ipsi.last, nil + } + currRange := ipsi.ranges[ipsi.rangeIdx] + if ipsi.last == currRange.To() { + // then we need to move to the next range + ipsi.rangeIdx++ + if ipsi.rangeIdx >= len(ipsi.ranges) { + // back to the beginning + ipsi.rangeIdx = 0 + } + ipsi.last = ipsi.ranges[ipsi.rangeIdx].From() + return ipsi.last, nil + } + ipsi.last = ipsi.last.Next() + return ipsi.last, nil +} + func newIPPool(ipset *netipx.IPSet) *ippool { if ipset == nil { return &ippool{} } - return &ippool{ranges: ipset.Ranges()} + return &ippool{ + ipSet: ipset, + ipSetIterator: &ipSetIterator{ranges: ipset.Ranges()}, + inUse: &set.Set[netip.Addr]{}, + } +} + +type ippool struct { + ipSet *netipx.IPSet + ipSetIterator *ipSetIterator + inUse *set.Set[netip.Addr] } -// next returns the next address from the set, or errPoolExhausted if we have -// iterated over the whole set. func (ipp *ippool) next() (netip.Addr, error) { - if ipp.rangeIdx >= len(ipp.ranges) { - // ipset is empty or we have iterated off the end - return netip.Addr{}, errPoolExhausted - } - if !ipp.last.IsValid() { - // not initialized yet - ipp.last = ipp.ranges[0].From() - return ipp.last, nil + a, err := ipp.ipSetIterator.next() + if err != nil { + return netip.Addr{}, err } - currRange := ipp.ranges[ipp.rangeIdx] - if ipp.last == currRange.To() { - // then we need to move to the next range - ipp.rangeIdx++ - if ipp.rangeIdx >= len(ipp.ranges) { + startedAt := a + for ipp.inUse.Contains(a) { + a, err = ipp.ipSetIterator.next() + if err != nil { + return a, err + } + if a == startedAt { return netip.Addr{}, errPoolExhausted } - ipp.last = ipp.ranges[ipp.rangeIdx].From() - return ipp.last, nil } - ipp.last = ipp.last.Next() - return ipp.last, nil + ipp.inUse.Add(a) + return a, nil +} + +func (ipp *ippool) returnAddr(a netip.Addr) error { + if !ipp.ipSet.Contains(a) { + return errNotOurAddress + } + if !ipp.inUse.Contains(a) { + return errAddrExists + } + ipp.inUse.Delete(a) + return nil } diff --git a/feature/conn25/ippool_test.go b/feature/conn25/ippool_test.go index ccfaad3eb71e1..431ea6998ac84 100644 --- a/feature/conn25/ippool_test.go +++ b/feature/conn25/ippool_test.go @@ -13,7 +13,7 @@ import ( ) func TestNext(t *testing.T) { - a := ippool{} + a := ipSetIterator{} _, err := a.next() if !errors.Is(err, errPoolExhausted) { t.Fatalf("expected errPoolExhausted, got %v", err) @@ -58,3 +58,88 @@ func TestNext(t *testing.T) { t.Fatalf("expected errPoolExhausted, got %v", err) } } + +// TestReturnAddr tests that if a pool is exhausted, an address can be returned to the +// pool, and then that address will be handed out again. +func TestReturnAddr(t *testing.T) { + addrString := "192.168.0.0" + // There's an IPPool with one address in it. + var isb netipx.IPSetBuilder + isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr(addrString), netip.MustParseAddr(addrString))) + ipset := must.Get(isb.IPSet()) + ipp := newIPPool(ipset) + // The first time we call next we get the address. + addr, err := ipp.next() + if err != nil { + t.Fatalf("expected nil error, got: %v", err) + } + if addr != netip.MustParseAddr(addrString) { + t.Fatalf("want %v, got %v", addrString, addr) + } + // The second time we call next we get errPoolExhausted + _, err = ipp.next() + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("expected errPoolExhausted, got %v", err) + } + // Return the addr to the pool + err = ipp.returnAddr(netip.MustParseAddr(addrString)) + if err != nil { + t.Fatal(err) + } + // It's not possible to return addresses that are already in the pool. + err = ipp.returnAddr(netip.MustParseAddr(addrString)) + if !errors.Is(err, errAddrExists) { + t.Fatalf("want errAddrExists, got: %v", err) + } + // When we call next we get the returned addr + addrAfterReturn, err := ipp.next() + if err != nil { + t.Fatalf("expected nil error, got: %v", err) + } + if addrAfterReturn != netip.MustParseAddr(addrString) { + t.Fatalf("want %v, got %v", addrString, addrAfterReturn) + } + // You can't return addresses that aren't from the pool. + err = ipp.returnAddr(netip.MustParseAddr("100.100.100.0")) + if !errors.Is(err, errNotOurAddress) { + t.Fatalf("want errNotOurAddress, got: %v", err) + } +} + +// TestGettingReturnedAddresses tests that when addresses are returned to the IP Pool +// they are then handed out in the order they were returned. +func TestGettingReturnedAddresses(t *testing.T) { + var isb netipx.IPSetBuilder + isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr("192.168.0.0"), netip.MustParseAddr("192.168.0.4"))) + ipset := must.Get(isb.IPSet()) + ipp := newIPPool(ipset) + expectAddrNext := func(addrString string) { + t.Helper() + got, err := ipp.next() + if err != nil { + t.Fatalf("expected nil error, got: %v", err) + } + want := netip.MustParseAddr(addrString) + if want != got { + t.Fatalf("want %v; got %v", want, got) + } + } + expectErrPoolExhaustedNext := func() { + t.Helper() + _, err := ipp.next() + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("expected errPoolExhausted; got %v", err) + } + } + expectAddrNext("192.168.0.0") + expectAddrNext("192.168.0.1") + expectAddrNext("192.168.0.2") + expectAddrNext("192.168.0.3") + expectAddrNext("192.168.0.4") + expectErrPoolExhaustedNext() + ipp.returnAddr(netip.MustParseAddr("192.168.0.2")) + ipp.returnAddr(netip.MustParseAddr("192.168.0.4")) + expectAddrNext("192.168.0.2") + expectAddrNext("192.168.0.4") + expectErrPoolExhaustedNext() +} diff --git a/feature/doctor/doctor.go b/feature/doctor/doctor.go index db061311b2e1f..01897f0a6478a 100644 --- a/feature/doctor/doctor.go +++ b/feature/doctor/doctor.go @@ -63,7 +63,7 @@ func visitDoctor(ctx context.Context, b *ipnlocal.LocalBackend, logf logger.Logf // IPs; this can interfere with our ability to connect to the Tailscale // controlplane. checks = append(checks, doctor.CheckFunc("dns-resolvers", func(_ context.Context, logf logger.Logf) error { - nm := b.NetMap() + nm := b.NetMapNoPeers() if nm == nil { return nil } diff --git a/feature/featuretags/featuretags.go b/feature/featuretags/featuretags.go index 4220c02b75fa2..e44a4f5922cc2 100644 --- a/feature/featuretags/featuretags.go +++ b/feature/featuretags/featuretags.go @@ -168,9 +168,9 @@ var Features = map[FeatureTag]FeatureMeta{ "health": {Sym: "Health", Desc: "Health checking support"}, "hujsonconf": {Sym: "HuJSONConf", Desc: "HuJSON config file support"}, "identityfederation": {Sym: "IdentityFederation", Desc: "Auth key generation via identity federation support"}, + "ipnbus": {Sym: "IPNBus", Desc: "IPN notification bus (watch-ipn-bus) support, used by GUIs, debugging, and nicer 'tailscale up' support"}, "iptables": {Sym: "IPTables", Desc: "Linux iptables support"}, "kube": {Sym: "Kube", Desc: "Kubernetes integration"}, - "lazywg": {Sym: "LazyWG", Desc: "Lazy WireGuard configuration for memory-constrained devices with large netmaps"}, "linuxdnsfight": {Sym: "LinuxDNSFight", Desc: "Linux support for detecting DNS fights (inotify watching of /etc/resolv.conf)"}, "linkspeed": { Sym: "LinkSpeed", @@ -269,6 +269,7 @@ var Features = map[FeatureTag]FeatureMeta{ "tailnetlock": {Sym: "TailnetLock", Desc: "Tailnet Lock support"}, "tap": {Sym: "Tap", Desc: "Experimental Layer 2 (ethernet) support"}, "tpm": {Sym: "TPM", Desc: "TPM support"}, + "tundevstats": {Sym: "TUNDevStats", Desc: "Poll TUN device statistics (Linux only)"}, "unixsocketidentity": { Sym: "UnixSocketIdentity", Desc: "differentiate between users accessing the LocalAPI over unix sockets (if omitted, all users have full access)", diff --git a/feature/hooks.go b/feature/hooks.go index 5cd3c0d818ca6..7611499a19011 100644 --- a/feature/hooks.go +++ b/feature/hooks.go @@ -67,6 +67,11 @@ func TPMAvailable() bool { return false } +// HookGetSSHHostKeyPublicStrings is a hook for the ssh/hostkeys package to +// provide SSH host key public strings to ipn/ipnlocal without ipnlocal needing +// to import golang.org/x/crypto/ssh. +var HookGetSSHHostKeyPublicStrings Hook[func(varRoot string, logf logger.Logf) ([]string, error)] + // HookHardwareAttestationAvailable is a hook that reports whether hardware // attestation is supported and available. var HookHardwareAttestationAvailable Hook[func() bool] diff --git a/feature/identityfederation/identityfederation.go b/feature/identityfederation/identityfederation.go index 4b96fd6a2020c..51a8018d8644d 100644 --- a/feature/identityfederation/identityfederation.go +++ b/feature/identityfederation/identityfederation.go @@ -128,8 +128,7 @@ func exchangeJWTForToken(ctx context.Context, baseURL, clientID, idToken string) }).Exchange(ctx, "", oauth2.SetAuthURLParam("client_id", clientID), oauth2.SetAuthURLParam("jwt", idToken)) if err != nil { // Try to extract more detailed error message - var retrieveErr *oauth2.RetrieveError - if errors.As(err, &retrieveErr) { + if retrieveErr, ok := errors.AsType[*oauth2.RetrieveError](err); ok { return "", fmt.Errorf("token exchange failed with status %d: %s", retrieveErr.Response.StatusCode, string(retrieveErr.Body)) } return "", fmt.Errorf("unexpected token exchange request error: %w", err) diff --git a/feature/identityfederation/identityfederation_test.go b/feature/identityfederation/identityfederation_test.go index 5e3660dc58725..9d9e5f4fdeaa5 100644 --- a/feature/identityfederation/identityfederation_test.go +++ b/feature/identityfederation/identityfederation_test.go @@ -31,7 +31,7 @@ func TestResolveAuthKey(t *testing.T) { wantErr: "", }, { - name: "missing client id short-circuits without error", + name: "missing-client-id-noop", clientID: "", idToken: "token", audience: "api://tailscale-wif", @@ -40,7 +40,7 @@ func TestResolveAuthKey(t *testing.T) { wantErr: "", }, { - name: "missing id token and audience", + name: "missing-id-token-and-audience", clientID: "client-123", idToken: "", audience: "", @@ -48,7 +48,7 @@ func TestResolveAuthKey(t *testing.T) { wantErr: "federated identity requires either an ID token or an audience", }, { - name: "missing tags", + name: "missing-tags", clientID: "client-123", idToken: "token", audience: "api://tailscale-wif", @@ -56,7 +56,7 @@ func TestResolveAuthKey(t *testing.T) { wantErr: "federated identity authkeys require --advertise-tags", }, { - name: "invalid client id attributes", + name: "invalid-client-id-attrs", clientID: "client-123?invalid=value", idToken: "token", audience: "api://tailscale-wif", @@ -99,7 +99,7 @@ func TestParseOptionalAttributes(t *testing.T) { wantErr string }{ { - name: "default values", + name: "default-values", clientID: "client-123", wantClientID: "client-123", wantEphemeral: true, @@ -107,7 +107,7 @@ func TestParseOptionalAttributes(t *testing.T) { wantErr: "", }, { - name: "custom values", + name: "custom-values", clientID: "client-123?ephemeral=false&preauthorized=true", wantClientID: "client-123", wantEphemeral: false, @@ -115,7 +115,7 @@ func TestParseOptionalAttributes(t *testing.T) { wantErr: "", }, { - name: "unknown attribute", + name: "unknown-attribute", clientID: "client-123?unknown=value", wantClientID: "", wantEphemeral: false, @@ -123,7 +123,7 @@ func TestParseOptionalAttributes(t *testing.T) { wantErr: `unknown optional config attribute "unknown"`, }, { - name: "invalid value", + name: "invalid-value", clientID: "client-123?ephemeral=invalid", wantClientID: "", wantEphemeral: false, diff --git a/feature/linuxdnsfight/linuxdnsfight_test.go b/feature/linuxdnsfight/linuxdnsfight_test.go index 661ba7f6f3a00..ce67353db297c 100644 --- a/feature/linuxdnsfight/linuxdnsfight_test.go +++ b/feature/linuxdnsfight/linuxdnsfight_test.go @@ -42,7 +42,7 @@ func TestWatchFile(t *testing.T) { // Keep writing until we get a callback. func() { for i := range 10000 { - if err := os.WriteFile(filepath, []byte(fmt.Sprintf("write%d", i)), 0644); err != nil { + if err := os.WriteFile(filepath, fmt.Appendf(nil, "write%d", i), 0644); err != nil { t.Fatal(err) } select { diff --git a/feature/oauthkey/oauthkey_test.go b/feature/oauthkey/oauthkey_test.go index f8027e45a922e..bb1de932662ff 100644 --- a/feature/oauthkey/oauthkey_test.go +++ b/feature/oauthkey/oauthkey_test.go @@ -20,42 +20,42 @@ func TestResolveAuthKey(t *testing.T) { wantErr bool }{ { - name: "keys without client secret prefix pass through unchanged", + name: "non-client-secret-passthrough", clientID: "tskey-auth-regular", tags: []string{"tag:test"}, wantAuthKey: "tskey-auth-regular", wantErr: false, }, { - name: "client secret without advertised tags", + name: "client-secret-no-tags", clientID: "tskey-client-abc", tags: nil, wantAuthKey: "", wantErr: true, }, { - name: "client secret with default attributes", + name: "client-secret-default-attrs", clientID: "tskey-client-abc", tags: []string{"tag:test"}, wantAuthKey: "tskey-auth-xyz", wantErr: false, }, { - name: "client secret with custom attributes", + name: "client-secret-custom-attrs", clientID: "tskey-client-abc?ephemeral=false&preauthorized=true", tags: []string{"tag:test"}, wantAuthKey: "tskey-auth-xyz", wantErr: false, }, { - name: "client secret with unknown attribute", + name: "client-secret-unknown-attr", clientID: "tskey-client-abc?unknown=value", tags: []string{"tag:test"}, wantAuthKey: "", wantErr: true, }, { - name: "oauth client secret with invalid attribute value", + name: "client-secret-invalid-attr-value", clientID: "tskey-client-abc?ephemeral=invalid", tags: []string{"tag:test"}, wantAuthKey: "", @@ -111,7 +111,7 @@ func TestResolveAuthKeyAttributes(t *testing.T) { wantBaseURL string }{ { - name: "default values", + name: "default-values", clientSecret: "tskey-client-abc", wantEphemeral: true, wantPreauth: false, @@ -132,14 +132,14 @@ func TestResolveAuthKeyAttributes(t *testing.T) { wantBaseURL: "https://api.tailscale.com", }, { - name: "baseURL=https://api.example.com", + name: "baseURL-custom", clientSecret: "tskey-client-abc?baseURL=https://api.example.com", wantEphemeral: true, wantPreauth: false, wantBaseURL: "https://api.example.com", }, { - name: "all custom values", + name: "all-custom-values", clientSecret: "tskey-client-abc?ephemeral=false&preauthorized=true&baseURL=https://api.example.com", wantEphemeral: false, wantPreauth: true, diff --git a/feature/posture/posture.go b/feature/posture/posture.go index d8db1ac1933fb..0c60d38b07601 100644 --- a/feature/posture/posture.go +++ b/feature/posture/posture.go @@ -8,8 +8,10 @@ package posture import ( "encoding/json" + "fmt" "net/http" + "tailscale.com/health" "tailscale.com/ipn/ipnext" "tailscale.com/ipn/ipnlocal" "tailscale.com/posture" @@ -25,6 +27,15 @@ func init() { ipnlocal.RegisterC2N("GET /posture/identity", handleC2NPostureIdentityGet) } +var postureSerialWarnable = health.Register(&health.Warnable{ + Code: "posture-checking-serial-collection-failed", + Title: "Device Posture: serial number collection failed", + Severity: health.SeverityMedium, + Text: func(args health.Args) string { + return fmt.Sprintf("Could not collect device serial numbers for posture checking. (%v)", args[health.ArgError]) + }, +}) + func newExtension(logf logger.Logf, b ipnext.SafeBackend) (ipnext.Extension, error) { e := &extension{ logf: logger.WithPrefix(logf, "posture: "), @@ -73,6 +84,9 @@ func handleC2NPostureIdentityGet(b *ipnlocal.LocalBackend, w http.ResponseWriter res.SerialNumbers, err = posture.GetSerialNumbers(b.PolicyClient(), e.logf) if err != nil { e.logf("c2n: GetSerialNumbers returned error: %v", err) + b.HealthTracker().SetUnhealthy(postureSerialWarnable, health.Args{health.ArgError: err.Error()}) + } else { + b.HealthTracker().SetHealthy(postureSerialWarnable) } // TODO(tailscale/corp#21371, 2024-07-10): once this has landed in a stable release diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go index 45d6abcc1d3d6..4f52a7ca748e7 100644 --- a/feature/relayserver/relayserver.go +++ b/feature/relayserver/relayserver.go @@ -23,7 +23,6 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logger" - "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/eventbus" "tailscale.com/wgengine/magicsock" @@ -225,7 +224,7 @@ func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsV e.stopRelayServerLocked() e.port = nil if ok { - e.port = ptr.To(newPort) + e.port = new(newPort) } } e.handleRelayServerLifetimeLocked() @@ -264,7 +263,7 @@ func (e *extension) serverStatus() status.ServerStatus { if e.rs == nil { return st } - st.UDPPort = ptr.To(*e.port) + st.UDPPort = new(*e.port) st.Sessions = e.rs.GetSessions() return st } diff --git a/feature/relayserver/relayserver_test.go b/feature/relayserver/relayserver_test.go index 730e25a00d0d3..97f4eb87418b6 100644 --- a/feature/relayserver/relayserver_test.go +++ b/feature/relayserver/relayserver_test.go @@ -18,15 +18,14 @@ import ( "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" - "tailscale.com/types/ptr" "tailscale.com/types/views" ) func Test_extension_profileStateChanged(t *testing.T) { - prefsWithPortOne := ipn.Prefs{RelayServerPort: ptr.To(uint16(1))} + prefsWithPortOne := ipn.Prefs{RelayServerPort: new(uint16(1))} prefsWithNilPort := ipn.Prefs{RelayServerPort: nil} prefsWithPortOneRelayEndpoints := ipn.Prefs{ - RelayServerPort: ptr.To(uint16(1)), + RelayServerPort: new(uint16(1)), RelayServerStaticEndpoints: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:7777")}, } @@ -49,36 +48,36 @@ func Test_extension_profileStateChanged(t *testing.T) { wantEndpoints []netip.AddrPort }{ { - name: "no changes non-nil port previously running", + name: "no-changes-non-nil-port-running", fields: fields{ - port: ptr.To(uint16(1)), + port: new(uint16(1)), rs: mockRelayServerNotZeroVal(), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: true, }, - wantPort: ptr.To(uint16(1)), + wantPort: new(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: false, }, { - name: "set addr ports unchanged port previously running", + name: "set-addr-ports-unchanged-running", fields: fields{ - port: ptr.To(uint16(1)), + port: new(uint16(1)), rs: mockRelayServerNotZeroVal(), }, args: args{ prefs: prefsWithPortOneRelayEndpoints.View(), sameNode: true, }, - wantPort: ptr.To(uint16(1)), + wantPort: new(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: false, wantEndpoints: prefsWithPortOneRelayEndpoints.RelayServerStaticEndpoints, }, { - name: "set addr ports not previously running", + name: "set-addr-ports-not-running", fields: fields{ port: nil, rs: nil, @@ -87,15 +86,15 @@ func Test_extension_profileStateChanged(t *testing.T) { prefs: prefsWithPortOneRelayEndpoints.View(), sameNode: true, }, - wantPort: ptr.To(uint16(1)), + wantPort: new(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, wantEndpoints: prefsWithPortOneRelayEndpoints.RelayServerStaticEndpoints, }, { - name: "clear addr ports unchanged port previously running", + name: "clear-addr-ports-unchanged-running", fields: fields{ - port: ptr.To(uint16(1)), + port: new(uint16(1)), staticEndpoints: views.SliceOf(prefsWithPortOneRelayEndpoints.RelayServerStaticEndpoints), rs: mockRelayServerNotZeroVal(), }, @@ -103,15 +102,15 @@ func Test_extension_profileStateChanged(t *testing.T) { prefs: prefsWithPortOne.View(), sameNode: true, }, - wantPort: ptr.To(uint16(1)), + wantPort: new(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: false, wantEndpoints: nil, }, { - name: "prefs port nil", + name: "prefs-port-nil", fields: fields{ - port: ptr.To(uint16(1)), + port: new(uint16(1)), }, args: args{ prefs: prefsWithNilPort.View(), @@ -122,9 +121,9 @@ func Test_extension_profileStateChanged(t *testing.T) { wantRelayServerFieldMutated: false, }, { - name: "prefs port nil previously running", + name: "prefs-port-nil-running", fields: fields{ - port: ptr.To(uint16(1)), + port: new(uint16(1)), rs: mockRelayServerNotZeroVal(), }, args: args{ @@ -136,61 +135,61 @@ func Test_extension_profileStateChanged(t *testing.T) { wantRelayServerFieldMutated: true, }, { - name: "prefs port changed", + name: "prefs-port-changed", fields: fields{ - port: ptr.To(uint16(2)), + port: new(uint16(2)), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: true, }, - wantPort: ptr.To(uint16(1)), + wantPort: new(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, }, { - name: "prefs port changed previously running", + name: "prefs-port-changed-running", fields: fields{ - port: ptr.To(uint16(2)), + port: new(uint16(2)), rs: mockRelayServerNotZeroVal(), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: true, }, - wantPort: ptr.To(uint16(1)), + wantPort: new(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, }, { - name: "sameNode false", + name: "sameNode-false", fields: fields{ - port: ptr.To(uint16(1)), + port: new(uint16(1)), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: false, }, - wantPort: ptr.To(uint16(1)), + wantPort: new(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, }, { - name: "sameNode false previously running", + name: "sameNode-false-running", fields: fields{ - port: ptr.To(uint16(1)), + port: new(uint16(1)), rs: mockRelayServerNotZeroVal(), }, args: args{ prefs: prefsWithPortOne.View(), sameNode: false, }, - wantPort: ptr.To(uint16(1)), + wantPort: new(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, }, { - name: "prefs port non-nil extension port nil", + name: "prefs-port-non-nil-ext-nil", fields: fields{ port: nil, }, @@ -198,7 +197,7 @@ func Test_extension_profileStateChanged(t *testing.T) { prefs: prefsWithPortOne.View(), sameNode: false, }, - wantPort: ptr.To(uint16(1)), + wantPort: new(uint16(1)), wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, }, @@ -278,41 +277,41 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) { wantRelayServerFieldMutated bool }{ { - name: "want running", + name: "want-running", shutdown: false, - port: ptr.To(uint16(1)), + port: new(uint16(1)), hasNodeAttrDisableRelayServer: false, wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: true, }, { - name: "want running previously running", + name: "want-running-previously-running", shutdown: false, - port: ptr.To(uint16(1)), + port: new(uint16(1)), rs: mockRelayServerNotZeroVal(), hasNodeAttrDisableRelayServer: false, wantRelayServerFieldNonNil: true, wantRelayServerFieldMutated: false, }, { - name: "shutdown true", + name: "shutdown-true", shutdown: true, - port: ptr.To(uint16(1)), + port: new(uint16(1)), hasNodeAttrDisableRelayServer: false, wantRelayServerFieldNonNil: false, wantRelayServerFieldMutated: false, }, { - name: "shutdown true previously running", + name: "shutdown-true-previously-running", shutdown: true, - port: ptr.To(uint16(1)), + port: new(uint16(1)), rs: mockRelayServerNotZeroVal(), hasNodeAttrDisableRelayServer: false, wantRelayServerFieldNonNil: false, wantRelayServerFieldMutated: true, }, { - name: "port nil", + name: "port-nil", shutdown: false, port: nil, hasNodeAttrDisableRelayServer: false, @@ -320,7 +319,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) { wantRelayServerFieldMutated: false, }, { - name: "port nil previously running", + name: "port-nil-previously-running", shutdown: false, port: nil, rs: mockRelayServerNotZeroVal(), @@ -329,7 +328,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) { wantRelayServerFieldMutated: true, }, { - name: "hasNodeAttrDisableRelayServer true", + name: "hasNodeAttrDisableRelayServer-true", shutdown: false, port: nil, hasNodeAttrDisableRelayServer: true, @@ -337,7 +336,7 @@ func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) { wantRelayServerFieldMutated: false, }, { - name: "hasNodeAttrDisableRelayServer true previously running", + name: "hasNodeAttrDisableRelayServer-true-running", shutdown: false, port: nil, rs: mockRelayServerNotZeroVal(), diff --git a/feature/ssh/ssh.go b/feature/ssh/ssh.go new file mode 100644 index 0000000000000..bd22005916d60 --- /dev/null +++ b/feature/ssh/ssh.go @@ -0,0 +1,11 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9) && !ts_omit_ssh + +// Package ssh registers the Tailscale SSH feature, including host key +// management and the SSH server. +package ssh + +// Register implementations of various SSH hooks. +import _ "tailscale.com/ssh/tailssh" diff --git a/feature/taildrop/doc.go b/feature/taildrop/doc.go index c394ebe82e18a..a3243b3c2aa50 100644 --- a/feature/taildrop/doc.go +++ b/feature/taildrop/doc.go @@ -1,5 +1,10 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -// Package taildrop registers the taildrop (file sending) feature. +// Package taildrop contains the implementation of the Taildrop +// functionality including sending and retrieving files. +// This package does not validate permissions, the caller should +// be responsible for ensuring correct authorization. +// +// For related documentation see: http://go/taildrop-how-does-it-work package taildrop diff --git a/feature/taildrop/ext.go b/feature/taildrop/ext.go index 3a4ed456d2269..abf574ebc5407 100644 --- a/feature/taildrop/ext.go +++ b/feature/taildrop/ext.go @@ -139,8 +139,8 @@ func (e *Extension) onChangeProfile(profile ipn.LoginProfileView, _ ipn.PrefsVie e.mu.Lock() defer e.mu.Unlock() - uid := profile.UserProfile().ID - activeLogin := profile.UserProfile().LoginName + uid := profile.UserProfile().ID() + activeLogin := profile.UserProfile().LoginName() if uid == 0 { e.setMgrLocked(nil) diff --git a/feature/taildrop/fileops_fs.go b/feature/taildrop/fileops_fs.go index 4a5b3e71a0f55..3ddf95d0314cd 100644 --- a/feature/taildrop/fileops_fs.go +++ b/feature/taildrop/fileops_fs.go @@ -101,7 +101,7 @@ func (f fsFileOps) Rename(oldPath, newName string) (newPath string, err error) { wantSize := st.Size() const maxRetries = 10 - for i := 0; i < maxRetries; i++ { + for range maxRetries { renameMu.Lock() fi, statErr := os.Stat(dst) // Atomically rename the partial file as the destination file if it doesn't exist. diff --git a/feature/taildrop/taildrop.go b/feature/taildrop/taildrop.go index 7042ca97aa7ef..9839b8330e597 100644 --- a/feature/taildrop/taildrop.go +++ b/feature/taildrop/taildrop.go @@ -1,12 +1,6 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -// Package taildrop contains the implementation of the Taildrop -// functionality including sending and retrieving files. -// This package does not validate permissions, the caller should -// be responsible for ensuring correct authorization. -// -// For related documentation see: http://go/taildrop-how-does-it-work package taildrop import ( diff --git a/feature/tailnetlock/tailnetlock.go b/feature/tailnetlock/tailnetlock.go new file mode 100644 index 0000000000000..325a13b087bc0 --- /dev/null +++ b/feature/tailnetlock/tailnetlock.go @@ -0,0 +1,54 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// package tailnetlock registers the tailnet lock debug C2N handler. In the +// future, all tailnet lock code should move here. +package tailnetlock + +import ( + "fmt" + "net/http" + "strconv" + + "tailscale.com/cmd/tailscale/cli/jsonoutput" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" + "tailscale.com/ipn/ipnlocal" +) + +func init() { + feature.Register("tailnetlock") + ipnlocal.RegisterC2N("/debug/tka/log", handleC2NDebugTKALog) +} + +const defaultC2NLogLimit = 50 +const maxC2NLogLimit = 1000 + +func handleC2NDebugTKALog(b *ipnlocal.LocalBackend, w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDebug { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } + + logf := b.Logger() + logf("c2n: %s %s received", r.Method, r.URL) + + limit := defaultC2NLogLimit + limitStr := r.URL.Query().Get("limit") + if limitStr != "" { + if parsed, err := strconv.Atoi(limitStr); err == nil { + limit = min(parsed, maxC2NLogLimit) + } + } + + updates, err := b.NetworkLockLog(limit) + if ipnlocal.IsNetworkLockNotActive(err) { + http.Error(w, "tailnet lock not active", http.StatusBadRequest) + return + } else if err != nil { + http.Error(w, fmt.Sprintf("failed to get tailnet lock log: %v", err), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + jsonoutput.PrintNetworkLockLogJSONV1(w, updates) +} diff --git a/feature/tailnetlock/tailnetlock_test.go b/feature/tailnetlock/tailnetlock_test.go new file mode 100644 index 0000000000000..771525d9dd844 --- /dev/null +++ b/feature/tailnetlock/tailnetlock_test.go @@ -0,0 +1,146 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tailnetlock + +import ( + "bytes" + "encoding/json" + "net/http/httptest" + "strings" + "testing" + + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tka" + "tailscale.com/types/key" + "tailscale.com/util/must" +) + +func TestHandleC2NDebugTKA(t *testing.T) { + makeTKA := func(length int) (tka.CompactableChonk, *tka.Authority) { + if length == 0 { + return nil, nil + } + + disablementSecret := bytes.Repeat([]byte{0xa5}, 32) + signerKey := key.NewNLPrivate() + key1 := tka.Key{Kind: tka.Key25519, Public: signerKey.Public().Verifier(), Votes: 2} + + chonk := tka.ChonkMem() + authority, _, err := tka.Create(chonk, tka.State{ + Keys: []tka.Key{key1}, + DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, + }, signerKey) + if err != nil { + t.Fatalf("tka.Create() failed: %v", err) + } + + for range length - 1 { + updater := authority.NewUpdater(signerKey) + key2 := tka.Key{Kind: tka.Key25519, Public: key.NewNLPrivate().Public().Verifier(), Votes: 2} + updater.AddKey(key2) + aums := must.Get(updater.Finalize(chonk)) + must.Do(authority.Inform(chonk, aums)) + } + + return chonk, authority + } + + bodyHead := func(body *bytes.Buffer) string { + count := 0 + var sb strings.Builder + for line := range strings.Lines(body.String()) { + if count == 10 { + sb.WriteString("...") + break + } + sb.WriteString(line) + count++ + } + return sb.String() + } + + // matches [jsonoutput.PrintNetworkLockLogJSONV1] + type response struct { + SchemaVersion string + Messages []any + } + + t.Run("tailnet-lock-disabled", func(t *testing.T) { + b := ipnlocal.LocalBackendWithTKAForTest(nil, nil) + + req := httptest.NewRequest("GET", "/debug/tka/log", nil) + rec := httptest.NewRecorder() + b.HandleC2NForTest(rec, req) + + if rec.Code != 400 { + t.Fatalf("got status code: %v, want: 400\nBody: %s", rec.Code, rec.Body) + } + }) + + t.Run("tailnet-lock-enabled", func(t *testing.T) { + chonk, authority := makeTKA(2) + b := ipnlocal.LocalBackendWithTKAForTest(chonk, authority) + + req := httptest.NewRequest("GET", "/debug/tka/log", nil) + rec := httptest.NewRecorder() + b.HandleC2NForTest(rec, req) + + if rec.Code != 200 { + t.Fatalf("got status code: %v, want: 200\nBody: %s", rec.Code, bodyHead(rec.Body)) + } + + var got response + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("couldn't parse JSON: %v\nbody: %s", err, bodyHead(rec.Body)) + } + + if len(got.Messages) != 2 { + t.Fatalf("got %d items, want 2", len(got.Messages)) + } + }) + + t.Run("default-limit", func(t *testing.T) { + chonk, authority := makeTKA(60) + b := ipnlocal.LocalBackendWithTKAForTest(chonk, authority) + + req := httptest.NewRequest("GET", "/debug/tka/log", nil) + rec := httptest.NewRecorder() + b.HandleC2NForTest(rec, req) + + if rec.Code != 200 { + t.Fatalf("got status code: %v, want: 200\nBody: %s", rec.Code, bodyHead(rec.Body)) + } + + var got response + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("couldn't parse JSON: %v\nbody: %s", err, bodyHead(rec.Body)) + } + + if len(got.Messages) != 50 { + t.Fatalf("got %d items, want 50", len(got.Messages)) + } + }) + + t.Run("override-limit", func(t *testing.T) { + chonk, authority := makeTKA(65) + b := ipnlocal.LocalBackendWithTKAForTest(chonk, authority) + + req := httptest.NewRequest("GET", "/debug/tka/log?limit=60", nil) + rec := httptest.NewRecorder() + b.HandleC2NForTest(rec, req) + + if rec.Code != 200 { + t.Fatalf("got status code: %v, want: 200\nBody: %s", rec.Code, bodyHead(rec.Body)) + } + + var got response + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("couldn't parse JSON: %v\nbody: %s", err, bodyHead(rec.Body)) + } + + if len(got.Messages) != 60 { + t.Fatalf("got %d items, want 60", len(got.Messages)) + } + }) +} diff --git a/feature/tundevstats/tundevstats_linux.go b/feature/tundevstats/tundevstats_linux.go new file mode 100644 index 0000000000000..13d5169c2a9ac --- /dev/null +++ b/feature/tundevstats/tundevstats_linux.go @@ -0,0 +1,442 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package tundevstats provides a mechanism for exposing TUN device statistics +// via clientmetrics. +package tundevstats + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "runtime" + "sync" + "time" + "unsafe" + + "github.com/mdlayher/netlink" + "github.com/tailscale/wireguard-go/tun" + "golang.org/x/sys/unix" + "tailscale.com/feature" + "tailscale.com/net/tstun" + "tailscale.com/util/clientmetric" +) + +func init() { + feature.Register("tundevstats") + if runtime.GOOS != "linux" { + // Exclude Android for now. There's no reason this shouldn't work on + // Android, but it needs to be tested, and justified from a battery + // cost perspective. + return + } + tstun.HookPollTUNDevStats.Set(newPoller) +} + +// poller polls TUN device stats via netlink, and surfaces them via +// [tailscale.com/util/clientmetric]. +type poller struct { + conn *netlink.Conn + ifIndex uint32 + closeCh chan struct{} + closeOnce sync.Once + wg sync.WaitGroup + lastTXQDrops uint64 +} + +// getIfIndex returns the interface index for ifName via ioctl. +func getIfIndex(ifName string) (uint32, error) { + ifr, err := unix.NewIfreq(ifName) + if err != nil { + return 0, err + } + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, + 0, + ) + if err != nil { + return 0, err + } + defer unix.Close(fd) + err = unix.IoctlIfreq(fd, unix.SIOCGIFINDEX, ifr) + if err != nil { + return 0, err + } + return ifr.Uint32(), nil +} + +type netlinkDialFn func(family int, config *netlink.Config) (*netlink.Conn, error) + +// newPollerWithNetlinkDialer exists to allow swapping [netlinkDialFn] in tests, +// but newPoller, which calls with [netlink.Dial], is what gets set as a +// [feature.Hook] in tstun. +func newPollerWithNetlinkDialer(tdev tun.Device, netlinkDialFn netlinkDialFn) (io.Closer, error) { + ifName, err := tdev.Name() + if err != nil { + return nil, fmt.Errorf("error getting device name: %w", err) + } + ifIndex, err := getIfIndex(ifName) + if err != nil { + return nil, fmt.Errorf("error getting ifIndex: %w", err) + } + conn, err := netlinkDialFn(unix.NETLINK_ROUTE, nil) + if err != nil { + return nil, fmt.Errorf("error opening netlink socket: %w", err) + } + p := &poller{ + conn: conn, + ifIndex: ifIndex, + closeCh: make(chan struct{}), + } + p.wg.Go(p.run) + return p, nil +} + +// newPoller starts polling device stats for tdev, returning an [io.Closer] +// that halts polling operations. +func newPoller(tdev tun.Device) (io.Closer, error) { + return newPollerWithNetlinkDialer(tdev, netlink.Dial) +} + +const ( + // pollInterval is how frequently [poller] polls TUN device statistics. Its + // value mirrors [tailscale.com/util/clientmetric.minMetricEncodeInterval], + // which is the minimum interval between clientmetrics emissions. + pollInterval = 15 * time.Second +) + +var ( + registerMetricOnce sync.Once + txQueueDrops *clientmetric.Metric +) + +// getTXQDropsMetric returns the TX queue drops clientmetric. It must not be +// called until device stats have been successfully polled via netlink since it +// sets the metric value to zero. A nil or absent clientmetric has meaning when +// polling fails, vs a misleading zero value. +func getTXQDropsMetric() *clientmetric.Metric { + registerMetricOnce.Do(func() { + txQueueDrops = clientmetric.NewCounter("tundev_txq_drops") + }) + return txQueueDrops +} + +func (p *poller) poll() error { + stats, err := getStats(p.conn, p.ifIndex) + if err != nil { + return err + } + m := getTXQDropsMetric() + delta := stats.txDropped - p.lastTXQDrops + m.Add(int64(delta)) + p.lastTXQDrops = stats.txDropped + return nil +} + +// run polls immediately and every [pollInterval] returning when [poller.poll] +// returns an error, or [poller.closeCh] is closed via [poller.Close]. +func (p *poller) run() { + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + err := p.poll() // poll immediately + if err != nil { + return + } + for { + select { + case <-p.closeCh: + return + case <-ticker.C: + err = p.poll() + if err != nil { + return + } + } + } +} + +// Close halts polling operations. +func (p *poller) Close() error { + p.closeOnce.Do(func() { + p.conn.Close() + close(p.closeCh) + p.wg.Wait() + }) + return nil +} + +// ifStatsMsg is struct if_stats_msg from uapi/linux/if_link.h. +type ifStatsMsg struct { + family uint8 + pad1 uint8 + pad2 uint16 + ifIndex uint32 + filterMask uint32 +} + +// encode encodes i in binary form for use over netlink in an RTM_GETSTATS +// request. +func (i *ifStatsMsg) encode() []byte { + return unsafe.Slice((*byte)(unsafe.Pointer(i)), unsafe.Sizeof(ifStatsMsg{})) +} + +const ( + iflaStatsLink64 = 1 // IFLA_STATS_LINK_64 from uapi/linux/if_link.h + iflaStatsLink64FilterMask = 1 << (iflaStatsLink64 - 1) +) + +// getStats returns [rtnlLinkStats64] via netlink RTM_GETSTATS over the provided +// conn for the provided ifIndex. +func getStats(conn *netlink.Conn, ifIndex uint32) (rtnlLinkStats64, error) { + reqData := ifStatsMsg{ + family: unix.AF_UNSPEC, + ifIndex: ifIndex, + filterMask: iflaStatsLink64FilterMask, + } + req := netlink.Message{ + Header: netlink.Header{ + Flags: netlink.Request, + Type: unix.RTM_GETSTATS, + }, + Data: reqData.encode(), + } + msgs, err := conn.Execute(req) + if err != nil { + return rtnlLinkStats64{}, err + } + if len(msgs) != 1 { + return rtnlLinkStats64{}, fmt.Errorf("expected one netlink response message, got: %d", len(msgs)) + } + msg := msgs[0] + if msg.Header.Type != unix.RTM_NEWSTATS { + return rtnlLinkStats64{}, fmt.Errorf("expected RTM_NEWSTATS (%d) netlink response, got: %d", unix.RTM_NEWSTATS, msg.Header.Type) + } + sizeOfIfStatsMsg := int(unsafe.Sizeof(ifStatsMsg{})) + if len(msg.Data) < sizeOfIfStatsMsg { + return rtnlLinkStats64{}, fmt.Errorf("length of netlink response data < %d, got: %d", sizeOfIfStatsMsg, len(msg.Data)) + } + ad, err := netlink.NewAttributeDecoder(msg.Data[sizeOfIfStatsMsg:]) + if err != nil { + return rtnlLinkStats64{}, err + } + for ad.Next() { + if ad.Type() == iflaStatsLink64 { + stats := rtnlLinkStats64{} + ad.Do(func(b []byte) error { + return stats.decode(b) + }) + if ad.Err() != nil { + return rtnlLinkStats64{}, ad.Err() + } + return stats, nil + } + } + if err = ad.Err(); err != nil { + return rtnlLinkStats64{}, err + } + return rtnlLinkStats64{}, errors.New("no stats found in netlink response") +} + +// rtnlLinkStats64 is struct rtnl_link_stats64 from uapi/linux/if_link.h up to +// the addition of the RTM_GETSTATS netlink message (Linux commit 10c9ead9f3c6). +// Newer fields are omitted. Since we expect this type in response to RTM_GETSTATS, +// we marry them together from a minimum kernel version perspective (Linux v4.7). +// Field documentation is copied from the kernel verbatim. +type rtnlLinkStats64 struct { + // rxPackets is the number of good packets received by the interface. + // For hardware interfaces counts all good packets received from the device + // by the host, including packets which host had to drop at various stages + // of processing (even in the driver). + rxPackets uint64 + + // txPackets is the number of packets successfully transmitted. + // For hardware interfaces counts packets which host was able to successfully + // hand over to the device, which does not necessarily mean that packets + // had been successfully transmitted out of the device, only that device + // acknowledged it copied them out of host memory. + txPackets uint64 + + // rxBytes is the number of good received bytes, corresponding to rxPackets. + // For IEEE 802.3 devices should count the length of Ethernet Frames + // excluding the FCS. + rxBytes uint64 + + // txBytes is the number of good transmitted bytes, corresponding to txPackets. + // For IEEE 802.3 devices should count the length of Ethernet Frames + // excluding the FCS. + txBytes uint64 + + // rxErrors is the total number of bad packets received on this network device. + // This counter must include events counted by rxLengthErrors, + // rxCRCErrors, rxFrameErrors and other errors not otherwise counted. + rxErrors uint64 + + // txErrors is the total number of transmit problems. + // This counter must include events counted by txAbortedErrors, + // txCarrierErrors, txFIFOErrors, txHeartbeatErrors, + // txWindowErrors and other errors not otherwise counted. + txErrors uint64 + + // rxDropped is the number of packets received but not processed, + // e.g. due to lack of resources or unsupported protocol. + // For hardware interfaces this counter may include packets discarded + // due to L2 address filtering but should not include packets dropped + // by the device due to buffer exhaustion which are counted separately in + // rxMissedErrors (since procfs folds those two counters together). + rxDropped uint64 + + // txDropped is the number of packets dropped on their way to transmission, + // e.g. due to lack of resources. + txDropped uint64 + + // multicast is the number of multicast packets received. + // For hardware interfaces this statistic is commonly calculated + // at the device level (unlike rxPackets) and therefore may include + // packets which did not reach the host. + // For IEEE 802.3 devices this counter may be equivalent to: + // - 30.3.1.1.21 aMulticastFramesReceivedOK + multicast uint64 + + // collisions is the number of collisions during packet transmissions. + collisions uint64 + + // rxLengthErrors is the number of packets dropped due to invalid length. + // Part of aggregate "frame" errors in /proc/net/dev. + // For IEEE 802.3 devices this counter should be equivalent to a sum of: + // - 30.3.1.1.23 aInRangeLengthErrors + // - 30.3.1.1.24 aOutOfRangeLengthField + // - 30.3.1.1.25 aFrameTooLongErrors + rxLengthErrors uint64 + + // rxOverErrors is the receiver FIFO overflow event counter. + // Historically the count of overflow events. Such events may be reported + // in the receive descriptors or via interrupts, and may not correspond + // one-to-one with dropped packets. + // The recommended interpretation for high speed interfaces is the number + // of packets dropped because they did not fit into buffers provided by the + // host, e.g. packets larger than MTU or next buffer in the ring was not + // available for a scatter transfer. + // Part of aggregate "frame" errors in /proc/net/dev. + // This statistic corresponds to hardware events and is not commonly used + // on software devices. + rxOverErrors uint64 + + // rxCRCErrors is the number of packets received with a CRC error. + // Part of aggregate "frame" errors in /proc/net/dev. + // For IEEE 802.3 devices this counter must be equivalent to: + // - 30.3.1.1.6 aFrameCheckSequenceErrors + rxCRCErrors uint64 + + // rxFrameErrors is the receiver frame alignment errors. + // Part of aggregate "frame" errors in /proc/net/dev. + // For IEEE 802.3 devices this counter should be equivalent to: + // - 30.3.1.1.7 aAlignmentErrors + rxFrameErrors uint64 + + // rxFIFOErrors is the receiver FIFO error counter. + // Historically the count of overflow events. Those events may be reported + // in the receive descriptors or via interrupts, and may not correspond + // one-to-one with dropped packets. + // This statistic is used on software devices, e.g. to count software + // packet queue overflow (can) or sequencing errors (GRE). + rxFIFOErrors uint64 + + // rxMissedErrors is the count of packets missed by the host. + // Folded into the "drop" counter in /proc/net/dev. + // Counts number of packets dropped by the device due to lack of buffer + // space. This usually indicates that the host interface is slower than + // the network interface, or host is not keeping up with the receive + // packet rate. + // This statistic corresponds to hardware events and is not used on + // software devices. + rxMissedErrors uint64 + + // txAbortedErrors is part of aggregate "carrier" errors in /proc/net/dev. + // For IEEE 802.3 devices capable of half-duplex operation this counter + // must be equivalent to: + // - 30.3.1.1.11 aFramesAbortedDueToXSColls + // High speed interfaces may use this counter as a general device discard + // counter. + txAbortedErrors uint64 + + // txCarrierErrors is the number of frame transmission errors due to loss + // of carrier during transmission. + // Part of aggregate "carrier" errors in /proc/net/dev. + // For IEEE 802.3 devices this counter must be equivalent to: + // - 30.3.1.1.13 aCarrierSenseErrors + txCarrierErrors uint64 + + // txFIFOErrors is the number of frame transmission errors due to device + // FIFO underrun / underflow. This condition occurs when the device begins + // transmission of a frame but is unable to deliver the entire frame to + // the transmitter in time for transmission. + // Part of aggregate "carrier" errors in /proc/net/dev. + txFIFOErrors uint64 + + // txHeartbeatErrors is the number of Heartbeat / SQE Test errors for + // old half-duplex Ethernet. + // Part of aggregate "carrier" errors in /proc/net/dev. + // For IEEE 802.3 devices possibly equivalent to: + // - 30.3.2.1.4 aSQETestErrors + txHeartbeatErrors uint64 + + // txWindowErrors is the number of frame transmission errors due to late + // collisions (for Ethernet - after the first 64B of transmission). + // Part of aggregate "carrier" errors in /proc/net/dev. + // For IEEE 802.3 devices this counter must be equivalent to: + // - 30.3.1.1.10 aLateCollisions + txWindowErrors uint64 + + // rxCompressed is the number of correctly received compressed packets. + // This counter is only meaningful for interfaces which support packet + // compression (e.g. CSLIP, PPP). + rxCompressed uint64 + + // txCompressed is the number of transmitted compressed packets. + // This counter is only meaningful for interfaces which support packet + // compression (e.g. CSLIP, PPP). + txCompressed uint64 + + // rxNoHandler is the number of packets received on the interface but + // dropped by the networking stack because the device is not designated + // to receive packets (e.g. backup link in a bond). + rxNoHandler uint64 +} + +// decode unpacks a [rtnlLinkStats64] from the raw bytes of a netlink attribute +// payload, e.g. IFLA_STATS_LINK_64. The kernel writes the struct in host byte +// order, so binary.NativeEndian is used throughout. The buffer may be larger +// than the struct to allow for future kernel additions. +func (s *rtnlLinkStats64) decode(b []byte) error { + const minSize = 24 * 8 + if len(b) < minSize { + return fmt.Errorf("rtnlLinkStats64.decode: buffer too short: got %d bytes, want at least %d", len(b), minSize) + } + s.rxPackets = binary.NativeEndian.Uint64(b[0:]) + s.txPackets = binary.NativeEndian.Uint64(b[8:]) + s.rxBytes = binary.NativeEndian.Uint64(b[16:]) + s.txBytes = binary.NativeEndian.Uint64(b[24:]) + s.rxErrors = binary.NativeEndian.Uint64(b[32:]) + s.txErrors = binary.NativeEndian.Uint64(b[40:]) + s.rxDropped = binary.NativeEndian.Uint64(b[48:]) + s.txDropped = binary.NativeEndian.Uint64(b[56:]) + s.multicast = binary.NativeEndian.Uint64(b[64:]) + s.collisions = binary.NativeEndian.Uint64(b[72:]) + s.rxLengthErrors = binary.NativeEndian.Uint64(b[80:]) + s.rxOverErrors = binary.NativeEndian.Uint64(b[88:]) + s.rxCRCErrors = binary.NativeEndian.Uint64(b[96:]) + s.rxFrameErrors = binary.NativeEndian.Uint64(b[104:]) + s.rxFIFOErrors = binary.NativeEndian.Uint64(b[112:]) + s.rxMissedErrors = binary.NativeEndian.Uint64(b[120:]) + s.txAbortedErrors = binary.NativeEndian.Uint64(b[128:]) + s.txCarrierErrors = binary.NativeEndian.Uint64(b[136:]) + s.txFIFOErrors = binary.NativeEndian.Uint64(b[144:]) + s.txHeartbeatErrors = binary.NativeEndian.Uint64(b[152:]) + s.txWindowErrors = binary.NativeEndian.Uint64(b[160:]) + s.rxCompressed = binary.NativeEndian.Uint64(b[168:]) + s.txCompressed = binary.NativeEndian.Uint64(b[176:]) + s.rxNoHandler = binary.NativeEndian.Uint64(b[184:]) + return nil +} diff --git a/feature/tundevstats/tundevstats_linux_test.go b/feature/tundevstats/tundevstats_linux_test.go new file mode 100644 index 0000000000000..05468039beb17 --- /dev/null +++ b/feature/tundevstats/tundevstats_linux_test.go @@ -0,0 +1,105 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tundevstats + +import ( + "encoding/binary" + "os" + "sync/atomic" + "testing" + "testing/synctest" + "time" + "unsafe" + + "github.com/mdlayher/netlink" + "github.com/mdlayher/netlink/nltest" + "github.com/tailscale/wireguard-go/tun" + "golang.org/x/sys/unix" +) + +func Test_getIfIndex(t *testing.T) { + ifIndex, err := getIfIndex("lo") + if err != nil { + t.Fatal(err) + } + if ifIndex != 1 { + // loopback ifIndex is effectively always 1 on Linux, see + // LOOPBACK_IFINDEX in the kernel (net/flow.h). + t.Fatalf("expected ifIndex of 1 for loopback, got: %d", ifIndex) + } +} + +type fakeDevice struct { + name string +} + +func (f *fakeDevice) File() *os.File { return nil } +func (f *fakeDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { return 0, nil } +func (f *fakeDevice) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } +func (f *fakeDevice) MTU() (int, error) { return 0, nil } +func (f *fakeDevice) Name() (string, error) { return f.name, nil } +func (f *fakeDevice) Events() <-chan tun.Event { return nil } +func (f *fakeDevice) Close() error { return nil } +func (f *fakeDevice) BatchSize() int { return 0 } + +func Test_poller(t *testing.T) { + getTXQDropsMetric().Set(0) // reset for test count > 1 + + var drops atomic.Uint64 + // dial is a [nltest.Func] that returns an RTM_NEWSTATS response with [drops] + // at the txDropped offset within the [rtnlLinkStats64] attribute payload. + dial := func(req []netlink.Message) ([]netlink.Message, error) { + if len(req) != 1 { + t.Fatalf("unexpected number of netlink request messages: %d", len(req)) + } + if req[0].Header.Type != unix.RTM_GETSTATS { + t.Fatalf("unexpected netlink request message type: %d want: %d", req[0].Header.Type, unix.RTM_GETSTATS) + } + data := make([]byte, unsafe.Sizeof(ifStatsMsg{})) + ae := netlink.NewAttributeEncoder() + ae.Do(iflaStatsLink64, func() ([]byte, error) { + ret := make([]byte, unsafe.Sizeof(rtnlLinkStats64{})) + binary.NativeEndian.PutUint64(ret[56:], drops.Load()) + return ret, nil + }) + attrs, err := ae.Encode() + if err != nil { + t.Fatal(err) + } + data = append(data, attrs...) + return []netlink.Message{ + { + Header: netlink.Header{ + Type: unix.RTM_NEWSTATS, + Sequence: req[0].Header.Sequence, + }, + Data: data, + }, + }, nil + } + + lo := &fakeDevice{name: "lo"} + drops.Store(1) + synctest.Test(t, func(t *testing.T) { + closer, err := newPollerWithNetlinkDialer(lo, func(family int, config *netlink.Config) (*netlink.Conn, error) { + return nltest.Dial(dial), nil + }) + if err != nil { + t.Fatal(err) + } + synctest.Wait() // first poll complete, poller.run() durably blocked in select + if got := getTXQDropsMetric().Value(); got != 1 { + t.Errorf("got drops: %d want: %d", got, 1) + } + drops.Store(2) // increment drops to 2 + time.Sleep(pollInterval) + synctest.Wait() // second poll complete, poller.run() durably blocked in select again + if got := getTXQDropsMetric().Value(); got != 2 { + t.Errorf("got drops: %d want: %d", got, 2) + } + closer.Close() + closer.Close() // multiple calls to Close() shouldn't panic + }) + +} diff --git a/flake.lock b/flake.lock index 1623342c62407..243188e431835 100644 --- a/flake.lock +++ b/flake.lock @@ -3,11 +3,11 @@ "flake-compat": { "flake": false, "locked": { - "lastModified": 1696426674, - "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "lastModified": 1767039857, + "narHash": "sha256-vNpUSpF5Nuw8xvDLj2KCwwksIbjua2LZCqhV1LNRDns=", "owner": "edolstra", "repo": "flake-compat", - "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "rev": "5edf11c44bc78a0d334f6334cdaf7d60d732daab", "type": "github" }, "original": { @@ -18,11 +18,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1753151930, - "narHash": "sha256-XSQy6wRKHhRe//iVY5lS/ZpI/Jn6crWI8fQzl647wCg=", + "lastModified": 1772736753, + "narHash": "sha256-au/m3+EuBLoSzWUCb64a/MZq6QUtOV8oC0D9tY2scPQ=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "83e677f31c84212343f4cc553bab85c2efcad60a", + "rev": "917fec990948658ef1ccd07cef2a1ef060786846", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index c9e3b50a1ad73..9dbb5abbefab0 100644 --- a/flake.nix +++ b/flake.nix @@ -48,14 +48,15 @@ }: let goVersion = nixpkgs.lib.fileContents ./go.toolchain.version; toolChainRev = nixpkgs.lib.fileContents ./go.toolchain.rev; - gitHash = nixpkgs.lib.fileContents ./go.toolchain.rev.sri; + flakeHashes = builtins.fromJSON (builtins.readFile ./flakehashes.json); + gitHash = flakeHashes.toolchain.sri; eachSystem = f: nixpkgs.lib.genAttrs (import systems) (system: f (import nixpkgs { system = system; overlays = [ (final: prev: { - go_1_26 = prev.go_1_26.overrideAttrs { + go_1_26 = prev.go_1_26.overrideAttrs (old: { version = goVersion; src = prev.fetchFromGitHub { owner = "tailscale"; @@ -63,7 +64,19 @@ rev = toolChainRev; sha256 = gitHash; }; - }; + # The Tailscale Go fork carries a placeholder in + # src/runtime/debug/mod.go that must be replaced with + # the actual toolchain git rev at build time. Without + # this, binaries report an empty tailscale.toolchain.rev + # and the runtime assertion in + # assert_ts_toolchain_match.go panics. + postPatch = + (old.postPatch or "") + + '' + substituteInPlace src/runtime/debug/mod.go \ + --replace-fail "TAILSCALE_GIT_REV_TO_BE_REPLACED_AT_BUILD_TIME" "${toolChainRev}" + ''; + }); }) ]; })); @@ -87,11 +100,11 @@ # you're an end user you should be prepared for this flake to not # build periodically. packages = eachSystem (pkgs: rec { - default = pkgs.buildGo125Module { + default = pkgs.buildGo126Module { name = "tailscale"; pname = "tailscale"; src = ./.; - vendorHash = pkgs.lib.fileContents ./go.mod.sri; + vendorHash = flakeHashes.vendor.sri; nativeBuildInputs = [pkgs.makeWrapper pkgs.installShellFiles]; ldflags = ["-X tailscale.com/version.gitCommitStamp=${tailscaleRev}"]; env.CGO_ENABLED = 0; @@ -151,4 +164,4 @@ }); }; } -# nix-direnv cache busting line: sha256-rhuWEEN+CtumVxOw6Dy/IRxWIrZ2x6RJb6ULYwXCQc4= +# nix-direnv cache busting line: sha256-mbxLXR2TBgiwyVGfLmMR5xWk+0f66mPDas95Wla70Lk= diff --git a/flakehashes.json b/flakehashes.json new file mode 100644 index 0000000000000..9ee6ccb99400e --- /dev/null +++ b/flakehashes.json @@ -0,0 +1,10 @@ +{ + "toolchain": { + "rev": "e877d973840c91ec9d4bc1921b0845789de359ae", + "sri": "sha256-HeD70CytKL0Ks/VDqMU73bN8fxpWkNc6mNgNr9PEO7k=" + }, + "vendor": { + "goModSum": "sha256-IbxUmMBapp3G2WIK+gqfmQd1tLCVoHMYBHLPZ5ZjDIU=", + "sri": "sha256-mbxLXR2TBgiwyVGfLmMR5xWk+0f66mPDas95Wla70Lk=" + } +} diff --git a/go.mod.sri b/go.mod.sri deleted file mode 100644 index a307075942f64..0000000000000 --- a/go.mod.sri +++ /dev/null @@ -1 +0,0 @@ -sha256-rhuWEEN+CtumVxOw6Dy/IRxWIrZ2x6RJb6ULYwXCQc4= diff --git a/go.sum b/go.sum index b61f1d24a1db1..295ad3aed41df 100644 --- a/go.sum +++ b/go.sum @@ -205,8 +205,8 @@ github.com/bombsimon/wsl/v4 v4.2.1 h1:Cxg6u+XDWff75SIFFmNsqnIOgob+Q9hG6y/ioKbRFi github.com/bombsimon/wsl/v4 v4.2.1/go.mod h1:Xu/kDxGZTofQcDGCtQe9KCzhHphIe0fDuyWTxER9Feo= github.com/bradfitz/go-tool-cache v0.0.0-20260216153636-9e5201344fe5 h1:0sG3c7afYdBNlc3QyhckvZ4bV9iqlfqCQM1i+mWm0eE= github.com/bradfitz/go-tool-cache v0.0.0-20260216153636-9e5201344fe5/go.mod h1:78ZLITnBUCDJeU01+wYYJKaPYYgsDzJPRfxeI8qFh5g= -github.com/bradfitz/monogok v0.0.0-20260208031948-2219c393d032 h1:xDomVqO85ss/98Ky5zxM/g86bXDNBLebM2I9G/fu6uA= -github.com/bradfitz/monogok v0.0.0-20260208031948-2219c393d032/go.mod h1:TG1HbU9fRVDnNgXncVkKz9GdvjIvqquXjH6QZSEVmY4= +github.com/bradfitz/monogok v0.0.0-20260429173803-229ef7981a6b h1:lhWZfi1U/yi8zuFA6pkJKYv45pVAC3xs6SUE2QsjsEE= +github.com/bradfitz/monogok v0.0.0-20260429173803-229ef7981a6b/go.mod h1:TG1HbU9fRVDnNgXncVkKz9GdvjIvqquXjH6QZSEVmY4= github.com/bramvdbogaerde/go-scp v1.4.0 h1:jKMwpwCbcX1KyvDbm/PDJuXcMuNVlLGi0Q0reuzjyKY= github.com/bramvdbogaerde/go-scp v1.4.0/go.mod h1:on2aH5AxaFb2G0N5Vsdy6B0Ml7k9HuHSwfo1y0QzAbQ= github.com/breml/bidichk v0.2.7 h1:dAkKQPLl/Qrk7hnP6P+E0xOodrq8Us7+U0o4UBOAlQY= @@ -268,8 +268,8 @@ github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/platforms v1.0.0-rc.2 h1:0SPgaNZPVWGEi4grZdV8VRYQn78y+nm6acgLGv/QzE4= github.com/containerd/platforms v1.0.0-rc.2/go.mod h1:J71L7B+aiM5SdIEqmd9wp6THLVRzJGXfNuWCZCllLA4= -github.com/containerd/stargz-snapshotter/estargz v0.18.1 h1:cy2/lpgBXDA3cDKSyEfNOFMA/c10O1axL69EU7iirO8= -github.com/containerd/stargz-snapshotter/estargz v0.18.1/go.mod h1:ALIEqa7B6oVDsrF37GkGN20SuvG/pIMm7FwP7ZmRb0Q= +github.com/containerd/stargz-snapshotter/estargz v0.18.2 h1:yXkZFYIzz3eoLwlTUZKz2iQ4MrckBxJjkmD16ynUTrw= +github.com/containerd/stargz-snapshotter/estargz v0.18.2/go.mod h1:XyVU5tcJ3PRpkA9XS2T5us6Eg35yM0214Y+wvrZTBrY= github.com/containerd/typeurl/v2 v2.2.3 h1:yNA/94zxWdvYACdYO8zofhrTVuQY73fFU1y++dYSw40= github.com/containerd/typeurl/v2 v2.2.3/go.mod h1:95ljDnPfD3bAbDJRugOiShd/DlAAsxGtUBhJxIn7SCk= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0= @@ -319,16 +319,12 @@ github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= github.com/djherbis/times v1.6.0/go.mod h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/docker/cli v29.0.3+incompatible h1:8J+PZIcF2xLd6h5sHPsp5pvvJA+Sr2wGQxHkRl53a1E= -github.com/docker/cli v29.0.3+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= -github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk= -github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= -github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM= -github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/cli v29.4.0+incompatible h1:+IjXULMetlvWJiuSI0Nbor36lcJ5BTcVpUmB21KBoVM= +github.com/docker/cli v29.4.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/docker-credential-helpers v0.9.3 h1:gAm/VtF9wgqJMoxzT3Gj5p4AqIjCBS4wrsOh9yRqcz8= github.com/docker/docker-credential-helpers v0.9.3/go.mod h1:x+4Gbw9aGmChi3qTLZj8Dfn0TD20M/fuWy0E5+WDeCo= -github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= -github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= +github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= +github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-events v0.0.0-20250808211157-605354379745 h1:yOn6Ze6IbYI/KAw2lw/83ELYvZh6hvsygTVkD0dzMC4= github.com/docker/go-events v0.0.0-20250808211157-605354379745/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA= github.com/docker/go-metrics v0.0.1 h1:AgB/0SvBxihN0X8OR4SjsblXkbMvalQ8cjmtKQ2rQV8= @@ -396,12 +392,12 @@ github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxI github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= -github.com/go-git/go-billy/v5 v5.6.2 h1:6Q86EsPXMa7c3YZ3aLAQsMA0VlWmy43r6FHqa/UNbRM= -github.com/go-git/go-billy/v5 v5.6.2/go.mod h1:rcFC2rAsp/erv7CMz9GczHcuD0D32fWzH+MJAU+jaUU= +github.com/go-git/go-billy/v5 v5.8.0 h1:I8hjc3LbBlXTtVuFNJuwYuMiHvQJDq1AT6u4DwDzZG0= +github.com/go-git/go-billy/v5 v5.8.0/go.mod h1:RpvI/rw4Vr5QA+Z60c6d6LXH0rYJo0uD5SqfmrrheCY= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= -github.com/go-git/go-git/v5 v5.16.5 h1:mdkuqblwr57kVfXri5TTH+nMFLNUxIj9Z7F5ykFbw5s= -github.com/go-git/go-git/v5 v5.16.5/go.mod h1:QOMLpNf1qxuSY4StA/ArOdfFR2TrKEjJiye2kel2m+M= +github.com/go-git/go-git/v5 v5.17.1 h1:WnljyxIzSj9BRRUlnmAU35ohDsjRK0EKmL0evDqi5Jk= +github.com/go-git/go-git/v5 v5.17.1/go.mod h1:pW/VmeqkanRFqR6AljLcs7EA7FbZaN5MQqO7oZADXpo= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -487,11 +483,13 @@ github.com/gokrazy/breakglass v0.0.0-20251229072214-9dbc0478d486/go.mod h1:PFPkR github.com/gokrazy/gokapi v0.0.0-20250222071133-506fdb322775 h1:f5+2UMRRbr3+e/gdWCBNn48chS/KMMljfbmlSSHfRBA= github.com/gokrazy/gokapi v0.0.0-20250222071133-506fdb322775/go.mod h1:q9mIV8al0wqmqFXJhKiO3SOHkL9/7Q4kIMynqUQWhgU= github.com/gokrazy/gokrazy v0.0.0-20200501080617-f3445e01a904/go.mod h1:pq6rGHqxMRPSaTXaCMzIZy0wLDusAJyoVNyNo05RLs0= -github.com/gokrazy/gokrazy v0.0.0-20260123094004-294c93fa173c h1:grjqEMf6dPJzZxf+gdo8rjx6bcyseO5p9hierlVkhXQ= -github.com/gokrazy/gokrazy v0.0.0-20260123094004-294c93fa173c/go.mod h1:NtMkrFeDGnwldKLi0dLdd2ipNwoVa7TI4HTxsy7lFRg= +github.com/gokrazy/gokrazy v0.0.0-20260418085648-c38c3134b8a7 h1:Isk3pOiVO5uj4BSrfRlQ16v6YpelnrTgMC618hEkKJ8= +github.com/gokrazy/gokrazy v0.0.0-20260418085648-c38c3134b8a7/go.mod h1:NtMkrFeDGnwldKLi0dLdd2ipNwoVa7TI4HTxsy7lFRg= github.com/gokrazy/internal v0.0.0-20200407075822-660ad467b7c9/go.mod h1:LA5TQy7LcvYGQOy75tkrYkFUhbV2nl5qEBP47PSi2JA= github.com/gokrazy/internal v0.0.0-20251208203110-3c1aa9087c82 h1:4ghNfD9NaZLpFrqQiBF6mPVFeMYXJSky38ubVA4ic2E= github.com/gokrazy/internal v0.0.0-20251208203110-3c1aa9087c82/go.mod h1:dQY4EMkD4L5ZjYJ0SPtpgYbV7MIUMCxNIXiOfnZ6jP4= +github.com/gokrazy/kernel.arm64 v0.0.0-20260403054012-807489e0272a h1:fa11POmSLo6fkkcqc+RUIyiqGJzBAOHEe/CCHAA/NGc= +github.com/gokrazy/kernel.arm64 v0.0.0-20260403054012-807489e0272a/go.mod h1:WWx72LXHEesuJxbopusRfSoKJQ6ffdwkT0DZditdrLo= github.com/gokrazy/serial-busybox v0.0.0-20250119153030-ac58ba7574e7 h1:gurTGc4sL7Ik+IKZ29rhGgHNZQTXPtEXLw+aM9E+/HE= github.com/gokrazy/serial-busybox v0.0.0-20250119153030-ac58ba7574e7/go.mod h1:OYcG5tSb+QrelmUOO4EZVUFcIHyyZb0QDbEbZFUp1TA= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= @@ -564,8 +562,8 @@ github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/go-containerregistry v0.20.7 h1:24VGNpS0IwrOZ2ms2P1QE3Xa5X9p4phx0aUgzYzHW6I= -github.com/google/go-containerregistry v0.20.7/go.mod h1:Lx5LCZQjLH1QBaMPeGwsME9biPeo1lPx6lbGj/UmzgM= +github.com/google/go-containerregistry v0.21.5 h1:KTJG9Pn/jC0VdZR6ctV3/jcN+q6/Iqlx0sTVz3ywZlM= +github.com/google/go-containerregistry v0.21.5/go.mod h1:ySvMuiWg+dOsRW0Hw8GYwfMwBlNRTmpYBFJPlkco5zU= github.com/google/go-github/v66 v66.0.0 h1:ADJsaXj9UotwdgK8/iFZtv7MLc8E8WBl62WLd/D/9+M= github.com/google/go-github/v66 v66.0.0/go.mod h1:+4SO9Zkuyf8ytMj0csN1NR/5OTR+MfqPp8P8dVlcvY4= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= @@ -596,8 +594,8 @@ github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgYQBbFN4U4JNXUNYpxael3UzMyo= github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg= -github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4= +github.com/google/renameio/v2 v2.0.2 h1:qKZs+tfn+arruZZhQ7TKC/ergJunuJicWS6gLDt/dGw= +github.com/google/renameio/v2 v2.0.2/go.mod h1:OX+G6WHHpHq3NVj7cAOleLOwJfcQ1s3uUJQCrr78SWo= github.com/google/rpmpack v0.5.0 h1:L16KZ3QvkFGpYhmp23iQip+mx1X39foEsqszjMNBm8A= github.com/google/rpmpack v0.5.0/go.mod h1:uqVAUVQLq8UY2hCDfmJ/+rtO3aw7qyhc90rCVEabEfI= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= @@ -753,6 +751,8 @@ github.com/karamaru-alpha/copyloopvar v1.0.8 h1:gieLARwuByhEMxRwM3GRS/juJqFbLraf github.com/karamaru-alpha/copyloopvar v1.0.8/go.mod h1:u7CIfztblY0jZLOQZgH3oYsJzpC2A7S6u/lfgSXHy0k= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/kdomanski/iso9660 v0.4.0 h1:BPKKdcINz3m0MdjIMwS0wx1nofsOjxOq8TOr45WGHFg= +github.com/kdomanski/iso9660 v0.4.0/go.mod h1:OxUSupHsO9ceI8lBLPJKWBTphLemjrCQY8LPXM7qSzU= github.com/kenshaw/evdev v0.1.0 h1:wmtceEOFfilChgdNT+c/djPJ2JineVsQ0N14kGzFRUo= github.com/kenshaw/evdev v0.1.0/go.mod h1:B/fErKCihUyEobz0mjn2qQbHgyJKFQAxkXSvkeeA/Wo= github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= @@ -763,8 +763,8 @@ github.com/kisielk/errcheck v1.7.0/go.mod h1:1kLL+jV4e+CFfueBmI1dSK2ADDyQnlrnrY/ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kkHAIKE/contextcheck v1.1.4 h1:B6zAaLhOEEcjvUgIYEqystmnFk1Oemn8bvJhbt0GMb8= github.com/kkHAIKE/contextcheck v1.1.4/go.mod h1:1+i/gWqokIa+dm31mqGLZhZJ7Uh44DJGZVmr6QRBNJg= -github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= -github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= github.com/klauspost/pgzip v1.2.6/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -871,12 +871,12 @@ github.com/moby/buildkit v0.20.2 h1:qIeR47eQ1tzI1rwz0on3Xx2enRw/1CKjFhoONVcTlMA= github.com/moby/buildkit v0.20.2/go.mod h1:DhaF82FjwOElTftl0JUAJpH/SUIUx4UvcFncLeOtlDI= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/moby/moby/api v1.54.1 h1:TqVzuJkOLsgLDDwNLmYqACUuTehOHRGKiPhvH8V3Nn4= +github.com/moby/moby/api v1.54.1/go.mod h1:+RQ6wluLwtYaTd1WnPLykIDPekkuyD/ROWQClE83pzs= +github.com/moby/moby/client v0.4.0 h1:S+2XegzHQrrvTCvF6s5HFzcrywWQmuVnhOXe2kiWjIw= +github.com/moby/moby/client v0.4.0/go.mod h1:QWPbvWchQbxBNdaLSpoKpCdf5E+WxFAgNHogCWDoa7g= github.com/moby/spdystream v0.5.0 h1:7r0J1Si3QO/kjRitvSLVVFUjxMEb/YLj6S9FF62JBCU= github.com/moby/spdystream v0.5.0/go.mod h1:xBAYlnt/ay+11ShkdFKNAG7LsyK/tmNBVvVOwrfMgdI= -github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw= -github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs= -github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU= -github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko= github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ= github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -891,8 +891,6 @@ github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00 h1:n6/ github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00/go.mod h1:Pm3mSP3c5uWn86xMLZ5Sa7JB9GsEZySvHYXCTK4E9q4= github.com/moricho/tparallel v0.3.1 h1:fQKD4U1wRMAYNngDonW5XupoB/ZGJHdpzrWqgyg9krA= github.com/moricho/tparallel v0.3.1/go.mod h1:leENX2cUv7Sv2qDgdi0D0fCftN8fRC67Bcn8pqzeYNI= -github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= -github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= @@ -968,8 +966,6 @@ github.com/poy/onpar v1.1.2 h1:QaNrNiZx0+Nar5dLgTVp5mXkyoVFIbepjyEoGSnhbAY= github.com/poy/onpar v1.1.2/go.mod h1:6X8FLNoxyr9kkmnlqpK6LSoiOtrO6MICtWwEuWkLjzg= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= -github.com/prometheus-community/pro-bing v0.4.0 h1:YMbv+i08gQz97OZZBwLyvmmQEEzyfyrrjEaAchdy3R4= -github.com/prometheus-community/pro-bing v0.4.0/go.mod h1:b7wRYZtCcPmt4Sz319BykUU241rWLe1VFXyiyWK/dH4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= @@ -1021,6 +1017,8 @@ github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRl github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/robert-nix/ansihtml v1.0.1 h1:VTiyQ6/+AxSJoSSLsMecnkh8i0ZqOEdiRl/odOc64fc= +github.com/robert-nix/ansihtml v1.0.1/go.mod h1:CJwclxYaTPc2RfcxtanEACsYuTksh4yDXcNeHHKZINE= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= @@ -1063,8 +1061,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= +github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/sivchari/containedctx v1.0.3 h1:x+etemjbsh2fB5ewm5FeLNi5bUjK0V8n0RB+Wwfd0XE= github.com/sivchari/containedctx v1.0.3/go.mod h1:c1RDvCbnJLtH4lLcYD/GqwiBSSf4F5Qk0xld2rBqzJ4= github.com/sivchari/tenv v1.7.1 h1:PSpuD4bu6fSmtWMxSGWcvqUUgIn7k3yOJhOIzVWn8Ak= @@ -1090,8 +1088,9 @@ github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiT github.com/spf13/jwalterweatherman v1.1.0 h1:ue6voC5bR5F8YxI5S67j9i582FU4Qvo2bmqnqMYADFk= github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.16.0 h1:rGGH0XDZhdUOryiDWjmIvUSWpbNqisK8Wk0Vyefw8hc= github.com/spf13/viper v1.16.0/go.mod h1:yg78JgCJcbrQOvV9YLXgkLaZqUidkY9K+Dd1FofRzQg= github.com/ssgreg/nlreturn/v2 v2.2.1 h1:X4XDI7jstt3ySqGU86YGAURbxw3oTDPK9sPEi6YEwQ0= @@ -1126,10 +1125,12 @@ github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8 github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c h1:+aPplBwWcHBo6q9xrfWdMrT9o4kltkmmvpemgIjep/8= github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c/go.mod h1:SbErYREK7xXdsRiigaQiQkI9McGRzYMvlKYaP3Nimdk= -github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e h1:PtWT87weP5LWHEY//SWsYkSO3RWRZo4OSWagh3YD2vQ= -github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e/go.mod h1:XrBNfAFN+pwoWuksbFS9Ccxnopa15zJGgXRFN90l3K4= +github.com/tailscale/certstore v0.1.1-0.20260409135935-3638fb84b77d h1:JcGKBZAL7ePLwOhUdN8qGQZlP5GueEiIZwY7R62pejE= +github.com/tailscale/certstore v0.1.1-0.20260409135935-3638fb84b77d/go.mod h1:XrBNfAFN+pwoWuksbFS9Ccxnopa15zJGgXRFN90l3K4= github.com/tailscale/depaware v0.0.0-20251001183927-9c2ad255ef3f h1:PDPGJtm9PFBLNudHGwkfUGp/FWvP+kXXJ0D1pB35F40= github.com/tailscale/depaware v0.0.0-20251001183927-9c2ad255ef3f/go.mod h1:p9lPsd+cx33L3H9nNoecRRxPssFKUwwI50I3pZ0yT+8= +github.com/tailscale/gliderssh v0.3.4-0.20260330083525-c1389c70ff89 h1:glgVc1ZYMjwN1Q/ITWeuSQyl029uayagaR2sjsifehc= +github.com/tailscale/gliderssh v0.3.4-0.20260330083525-c1389c70ff89/go.mod h1:wn16Km1EZOX4UEAyaZa3dBwfFGOJ7neck40NcwosJUw= github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55 h1:Gzfnfk2TWrk8Jj4P4c1a3CtQyMaTVCznlkLZI++hok4= github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55/go.mod h1:4k4QO+dQ3R5FofL+SanAUZe+/QfeK0+OIuwDIRu2vSg= github.com/tailscale/goexpect v0.0.0-20210902213824-6e8c725cea41 h1:/V2rCMMWcsjYaYO2MeovLw+ClP63OtXgCF2Y1eb8+Ns= @@ -1138,8 +1139,8 @@ github.com/tailscale/gokrazy-kernel v0.0.0-20240728225134-3d23beabda2e h1:tyUUge github.com/tailscale/gokrazy-kernel v0.0.0-20240728225134-3d23beabda2e/go.mod h1:7Mth+m9bq2IHusSsexMNyupHWPL8RxwOuSvBlSGtgDY= github.com/tailscale/golang-x-crypto v0.0.0-20250404221719-a5573b049869 h1:SRL6irQkKGQKKLzvQP/ke/2ZuB7Py5+XuqtOgSj+iMM= github.com/tailscale/golang-x-crypto v0.0.0-20250404221719-a5573b049869/go.mod h1:ikbF+YT089eInTp9f2vmvy4+ZVnW5hzX1q2WknxSprQ= -github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a h1:SJy1Pu0eH1C29XwJucQo73FrleVK6t4kYz4NVhp34Yw= -github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a/go.mod h1:DFSS3NAGHthKo1gTlmEcSBiZrRJXi28rLNd/1udP1c8= +github.com/tailscale/hujson v0.0.0-20260302212456-ecc657c15afd h1:Rf9uhF1+VJ7ZHqxrG8pJ6YacmHvVCmByDmGbAWCc/gA= +github.com/tailscale/hujson v0.0.0-20260302212456-ecc657c15afd/go.mod h1:EbW0wDK/qEUYI0A5bqq0C2kF8JTQwWONmGDBbzsxxHo= github.com/tailscale/mkctr v0.0.0-20260107121656-ea857e3e500b h1:QKqCnmp0qHWUHySySKjpuhZANzRn7XrTVZWUuUgJ3lQ= github.com/tailscale/mkctr v0.0.0-20260107121656-ea857e3e500b/go.mod h1:4st7fy3NTWcWsQdOC69JcHK4UXnncgcxSOvSR8aD8a0= github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4ZoF094vE6iYTLDl0qCiKzYXlL6UeWObU= @@ -1148,12 +1149,14 @@ github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+y github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= github.com/tailscale/setec v0.0.0-20251203133219-2ab774e4129a h1:TApskGPim53XY5WRt5hX4DnO8V6CmVoimSklryIoGMM= github.com/tailscale/setec v0.0.0-20251203133219-2ab774e4129a/go.mod h1:+6WyG6kub5/5uPsMdYQuSti8i6F5WuKpFWLQnZt/Mms= +github.com/tailscale/ts-gokrazy v0.0.0-20260429180033-fe741c6deb44 h1:a6GdEBrBcDy/4XQ2CxKQvuCaKN8EFL5JTE7ZFOkXDzQ= +github.com/tailscale/ts-gokrazy v0.0.0-20260429180033-fe741c6deb44/go.mod h1:mu0sethAvP7xItcfBAxMJWiXZ3ZQ5qbKmjPYizOkSHE= github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14= github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15g8OToDdASBni4CyJOdHY1Hr8M= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6/go.mod h1:ZXRML051h7o4OcI0d3AaILDIad/Xw0IkXaHM17dic1Y= -github.com/tailscale/wireguard-go v0.0.0-20250716170648-1d0488a3d7da h1:jVRUZPRs9sqyKlYHHzHjAqKN+6e/Vog6NpHYeNPJqOw= -github.com/tailscale/wireguard-go v0.0.0-20250716170648-1d0488a3d7da/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= +github.com/tailscale/wireguard-go v0.0.0-20260427181203-e3ac4a0afb4e h1:GexFR7ak1iz26fxg8HWCpOEqAOL8UEZJ7J3JxeCalDs= +github.com/tailscale/wireguard-go v0.0.0-20260427181203-e3ac4a0afb4e/go.mod h1:6SerzcvHWQchKO2BfNdmquA77CHSECZuFl+D9fp4RnI= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA= @@ -1197,10 +1200,9 @@ github.com/uudashr/gocognit v1.1.2 h1:l6BAEKJqQH2UpKAPKdMfZf5kE4W/2xk8pfU1OVLvni github.com/uudashr/gocognit v1.1.2/go.mod h1:aAVdLURqcanke8h3vg35BC++eseDm66Z7KmchI5et4k= github.com/vbatts/tar-split v0.12.2 h1:w/Y6tjxpeiFMR47yzZPlPj/FcPLpXbTUi/9H7d3CPa4= github.com/vbatts/tar-split v0.12.2/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= -github.com/vishvananda/netlink v1.3.1-0.20240922070040-084abd93d350 h1:w5OI+kArIBVksl8UGn6ARQshtPCQvDsbuA9NQie3GIg= -github.com/vishvananda/netlink v1.3.1-0.20240922070040-084abd93d350/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= +github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= +github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= -github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= @@ -1265,8 +1267,8 @@ go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 h1:OeNbIYk/2C15ckl7glB go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0/go.mod h1:7Bept48yIeqxP2OZ9/AqIpYS94h2or0aB4FypJTc8ZM= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.34.0 h1:tgJ0uaNS4c98WRNUEx5U3aDlrDOI5Rs+1Vifcw4DJ8U= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.34.0/go.mod h1:U7HYyW0zt/a9x5J1Kjs+r1f/d4ZHnYFclhYY2+YbeoE= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.33.0 h1:wpMfgF8E1rkrT1Z6meFh1NDtownE9Ii3n3X2GJYjsaU= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.33.0/go.mod h1:wAy0T/dUbs468uOlkT31xjvqQgEVXv58BRFWEgn5v/0= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.32.0 h1:cMyu9O88joYEaI47CnQkxO1XZdpoTF9fEnW2duIddhw= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.32.0/go.mod h1:6Am3rn7P9TVVeXYG+wtcGE7IE1tsQ+bP3AuWcKt/gOI= go.opentelemetry.io/otel/exporters/prometheus v0.54.0 h1:rFwzp68QMgtzu9PgP3jm9XaMICI6TsofWWPcBDKwlsU= go.opentelemetry.io/otel/exporters/prometheus v0.54.0/go.mod h1:QyjcV9qDP6VeK5qPyKETvNjmaaEc7+gqjh4SS0ZYzDU= go.opentelemetry.io/otel/exporters/stdout/stdoutlog v0.8.0 h1:CHXNXwfKWfzS65yrlB2PVds1IBZcdsX8Vepy9of0iRU= @@ -1315,8 +1317,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/crypto/x509roots/fallback v0.0.0-20260113154411-7d0074ccc6f1 h1:EBHQuS9qI8xJ96+YRgVV2ahFLUYbWpt1rf3wPfXN2wQ= golang.org/x/crypto/x509roots/fallback v0.0.0-20260113154411-7d0074ccc6f1/go.mod h1:MEIPiCnxvQEjA4astfaKItNwEVZA5Ki+3+nyGbJ5N18= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1366,8 +1368,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= -golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= +golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= +golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1407,16 +1409,16 @@ golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo= -golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1430,8 +1432,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1497,18 +1499,18 @@ golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54 h1:E2/AqCUMZGgd73TQkxUMcMla25GB9i/5HOdLr+uH7Vo= -golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54/go.mod h1:hKdjCMrbv9skySur+Nek8Hd0uJ0GuxJIoIX2payrIdQ= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/telemetry v0.0.0-20260409153401-be6f6cb8b1fa h1:efT73AJZfAAUV7SOip6pWGkwJDzIGiKBZGVzHYa+ve4= +golang.org/x/telemetry v0.0.0-20260409153401-be6f6cb8b1fa/go.mod h1:kHjTxDEnAu6/Nl9lDkzjWpR+bmKfxeiRuSDlsMb70gE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= -golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1519,8 +1521,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1590,8 +1592,8 @@ golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= golang.org/x/tools v0.5.0/go.mod h1:N+Kgy78s5I24c24dU8OfWNEotWjutIs8SnJvn5IDq+k= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= -golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= +golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c= +golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI= golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM= golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM= @@ -1724,8 +1726,8 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= -gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= +gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= +gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= gvisor.dev/gvisor v0.0.0-20260224225140-573d5e7127a8 h1:Zy8IV/+FMLxy6j6p87vk/vQGKcdnbprwjTxc8UiUtsA= gvisor.dev/gvisor v0.0.0-20260224225140-573d5e7127a8/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q= helm.sh/helm/v3 v3.19.0 h1:krVyCGa8fa/wzTZgqw0DUiXuRT5BPdeqE/sQXujQ22k= @@ -1737,8 +1739,8 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.7.0-0.dev.0.20251022135355-8273271481d0 h1:5SXjd4ET5dYijLaf0O3aOenC0Z4ZafIWSpjUzsQaNho= -honnef.co/go/tools v0.7.0-0.dev.0.20251022135355-8273271481d0/go.mod h1:EPDDhEZqVHhWuPI5zPAsjU0U7v9xNIWjoOVyZ5ZcniQ= +honnef.co/go/tools v0.7.0 h1:w6WUp1VbkqPEgLz4rkBzH/CSU6HkoqNLp6GstyTx3lU= +honnef.co/go/tools v0.7.0/go.mod h1:pm29oPxeP3P82ISxZDgIYeOaf9ta6Pi0EWvCFoLG2vc= howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM= howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= k8s.io/api v0.34.0 h1:L+JtP2wDbEYPUeNGbeSa/5GwFtIA662EmT2YSLOkAVE= @@ -1769,6 +1771,8 @@ mvdan.cc/unparam v0.0.0-20240104100049-c549a3470d14 h1:zCr3iRRgdk5eIikZNDphGcM6K mvdan.cc/unparam v0.0.0-20240104100049-c549a3470d14/go.mod h1:ZzZjEpJDOmx8TdVU6umamY3Xy0UAQUI2DHbf05USVbI= oras.land/oras-go/v2 v2.6.0 h1:X4ELRsiGkrbeox69+9tzTu492FMUu7zJQW6eJU+I2oc= oras.land/oras-go/v2 v2.6.0/go.mod h1:magiQDfG6H1O9APp+rOsvCPcW1GD2MM7vgnKY0Y+u1o= +pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk= +pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= @@ -1792,3 +1796,5 @@ sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs= sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4= software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k= software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= +tailscale.com/client/tailscale/v2 v2.9.0 h1:zBZIIeIYXL42qvvile7d29O2DKSr3AfNc2gzd1JCf2o= +tailscale.com/client/tailscale/v2 v2.9.0/go.mod h1:FGjvGT3ThHelqo0gfdK3IN3k1dwNbRzYbQh2XO3C47U= diff --git a/gokrazy/Makefile b/gokrazy/Makefile index bc55f2a52acb5..014866851bf74 100644 --- a/gokrazy/Makefile +++ b/gokrazy/Makefile @@ -11,3 +11,8 @@ qemu: image natlab: go run build.go --build --app=natlabapp qemu-img convert -O qcow2 natlabapp.img natlabapp.qcow2 + +# For natlab integration tests on macOS arm64: +natlab-arm64: + go run build.go --build --app=natlabapp.arm64 + qemu-img convert -O qcow2 natlabapp.arm64.img natlabapp.arm64.qcow2 diff --git a/gokrazy/gokrazy_test.go b/gokrazy/gokrazy_test.go new file mode 100644 index 0000000000000..76398d49bf594 --- /dev/null +++ b/gokrazy/gokrazy_test.go @@ -0,0 +1,286 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "encoding/json" + "flag" + "fmt" + "hash/fnv" + "io" + "net" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "golang.org/x/mod/modfile" +) + +var runVMTests = flag.Bool("run-vm-tests", false, "run tests that require a VM") + +func findKernelPath(t *testing.T) string { + t.Helper() + goModPath := filepath.Join("..", "go.mod") + b, err := os.ReadFile(goModPath) + if err != nil { + t.Fatalf("reading go.mod: %v", err) + } + mf, err := modfile.Parse("go.mod", b, nil) + if err != nil { + t.Fatalf("parsing go.mod: %v", err) + } + goModB, err := exec.Command("go", "env", "GOMODCACHE").CombinedOutput() + if err != nil { + t.Fatalf("go env GOMODCACHE: %v", err) + } + for _, r := range mf.Require { + if r.Mod.Path == "github.com/tailscale/gokrazy-kernel" { + return strings.TrimSpace(string(goModB)) + "/" + r.Mod.String() + "/vmlinuz" + } + } + t.Fatal("failed to find gokrazy-kernel in go.mod") + return "" +} + +// gptPartuuid returns the GPT PARTUUID for a gokrazy appliance partition, +// matching the scheme used by monogok: fnv32a(hostname) formatted into +// the gokrazy GUID prefix. +func gptPartuuid(hostname string, partition uint16) string { + h := fnv.New32a() + h.Write([]byte(hostname)) + return fmt.Sprintf("60c24cc1-f3f9-427a-8199-%08x00%02x", h.Sum32(), partition) +} + +func buildTsappImage(t *testing.T) string { + t.Helper() + imgPath, err := filepath.Abs("tsapp.img") + if err != nil { + t.Fatal(err) + } + if _, err := os.Stat(imgPath); err == nil { + t.Logf("using existing tsapp.img: %s", imgPath) + return imgPath + } + + t.Logf("building tsapp.img...") + cmd := exec.Command("make", "image") + cmd.Dir, _ = os.Getwd() + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + t.Fatalf("make image: %v", err) + } + if _, err := os.Stat(imgPath); err != nil { + t.Fatalf("tsapp.img not found after build: %v", err) + } + return imgPath +} + +// serialLog collects serial console output in a thread-safe manner. +type serialLog struct { + mu sync.Mutex + lines []string +} + +func (sl *serialLog) add(line string) { + sl.mu.Lock() + defer sl.mu.Unlock() + sl.lines = append(sl.lines, line) +} + +func (sl *serialLog) lastN(n int) []string { + sl.mu.Lock() + defer sl.mu.Unlock() + if len(sl.lines) <= n { + cp := make([]string, len(sl.lines)) + copy(cp, sl.lines) + return cp + } + cp := make([]string, n) + copy(cp, sl.lines[len(sl.lines)-n:]) + return cp +} + +func (sl *serialLog) findLine(pred func(string) bool) bool { + sl.mu.Lock() + defer sl.mu.Unlock() + for _, line := range sl.lines { + if pred(line) { + return true + } + } + return false +} + +// TestBusyboxInTsapp boots the tsapp image in QEMU and verifies that +// busybox is accessible via the serial console shell. This validates +// that the serial-busybox package's extra files (the busybox binary) +// are properly included in the image by monogok. +func TestBusyboxInTsapp(t *testing.T) { + if !*runVMTests { + t.Skip("skipping VM test; set --run-vm-tests to run") + } + + kernel := findKernelPath(t) + if _, err := os.Stat(kernel); err != nil { + t.Skipf("kernel not found at %s: %v", kernel, err) + } + t.Logf("kernel: %s", kernel) + + // Read the hostname from config.json to compute the GPT PARTUUID. + cfgBytes, err := os.ReadFile("tsapp/config.json") + if err != nil { + t.Fatalf("reading tsapp/config.json: %v", err) + } + var cfg struct { + Hostname string + } + if err := json.Unmarshal(cfgBytes, &cfg); err != nil { + t.Fatalf("parsing config.json: %v", err) + } + rootParam := fmt.Sprintf("root=PARTUUID=%s/PARTNROFF=1", gptPartuuid(cfg.Hostname, 1)) + t.Logf("root param: %s", rootParam) + + imgPath := buildTsappImage(t) + + // Create a temporary qcow2 overlay so we don't modify the original image. + tmpDir := t.TempDir() + disk := filepath.Join(tmpDir, "tsapp-test.qcow2") + out, err := exec.Command("qemu-img", "create", + "-f", "qcow2", + "-F", "raw", + "-b", imgPath, + disk).CombinedOutput() + if err != nil { + t.Fatalf("qemu-img create: %v, %s", err, out) + } + + // Set up a Unix socket for the serial console. + sockPath := filepath.Join(tmpDir, "serial.sock") + ln, err := net.Listen("unix", sockPath) + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + // Boot QEMU with microvm, explicit kernel, and serial via virtconsole + // connected to our Unix socket. The kernel sees hvc0 as the console + // device, and gokrazy uses it for the serial shell. + cmd := exec.Command("qemu-system-x86_64", + "-M", "microvm,isa-serial=off", + "-m", "1G", + "-nodefaults", "-no-user-config", "-nographic", + "-kernel", kernel, + "-append", "console=hvc0 "+rootParam+" ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet", + "-drive", "id=blk0,file="+disk+",format=qcow2", + "-device", "virtio-blk-device,drive=blk0", + "-device", "virtio-rng-device", + "-device", "virtio-serial-device", + "-chardev", "socket,id=virtiocon0,path="+sockPath+",server=off", + "-device", "virtconsole,chardev=virtiocon0", + "-netdev", "user,id=net0", + "-device", "virtio-net-device,netdev=net0", + ) + cmd.Stderr = os.Stderr + if err := cmd.Start(); err != nil { + t.Fatalf("qemu start: %v", err) + } + t.Cleanup(func() { + cmd.Process.Kill() + cmd.Wait() + }) + + // Accept the serial console connection from QEMU. + ln.(*net.UnixListener).SetDeadline(time.Now().Add(30 * time.Second)) + conn, err := ln.Accept() + if err != nil { + t.Fatalf("accept serial connection: %v", err) + } + defer conn.Close() + + // Read serial output in a goroutine. + slog := &serialLog{} + bootDone := make(chan struct{}) + go func() { + buf := make([]byte, 4096) + var partial string + for { + n, err := conn.Read(buf) + if n > 0 { + partial += string(buf[:n]) + for { + idx := strings.IndexByte(partial, '\n') + if idx < 0 { + break + } + line := strings.TrimRight(partial[:idx], "\r") + partial = partial[idx+1:] + slog.add(line) + t.Logf("serial: %s", line) + // gokrazy logs socket listener info when boot is done. + if strings.Contains(line, "listening on") { + select { + case <-bootDone: + default: + close(bootDone) + } + } + } + } + if err != nil { + if err != io.EOF { + t.Logf("serial read error: %v", err) + } + return + } + } + }() + + // Wait for boot to complete (up to 120 seconds). + select { + case <-bootDone: + t.Logf("boot complete") + case <-time.After(120 * time.Second): + t.Fatalf("timeout waiting for boot; last lines:\n%s", + strings.Join(slog.lastN(20), "\n")) + } + + // Small delay to let services fully initialize. + time.Sleep(2 * time.Second) + + // Send a newline to trigger the serial shell. + // gokrazy's init reads stdin and calls tryStartShell() on any input. + fmt.Fprintf(conn, "\n") + time.Sleep(2 * time.Second) + + // Send a command to test busybox. The echo command is a busybox builtin, + // so if busybox is working, we'll see our marker in the output. + marker := "BUSYBOX_TEST_OK_12345" + fmt.Fprintf(conn, "echo %s\n", marker) + + // Wait for our marker in the output (not on the echo command line itself). + deadline := time.After(15 * time.Second) + for { + select { + case <-deadline: + t.Fatalf("timeout waiting for busybox echo response; busybox binary is likely missing from the image.\n"+ + "This indicates monogok is not copying _gokrazy/extrafiles from serial-busybox.\n"+ + "Last serial lines:\n%s", + strings.Join(slog.lastN(30), "\n")) + default: + } + time.Sleep(200 * time.Millisecond) + // Look for the marker on a line by itself (the echo output, not the command). + if slog.findLine(func(line string) bool { + return strings.TrimSpace(line) == marker + }) { + t.Logf("busybox shell is working: got echo response") + return // success + } + } +} diff --git a/gokrazy/natlabapp.arm64/config.json b/gokrazy/natlabapp.arm64/config.json index 2ba9a20f9510f..8283dc053dc31 100644 --- a/gokrazy/natlabapp.arm64/config.json +++ b/gokrazy/natlabapp.arm64/config.json @@ -27,5 +27,7 @@ "KernelPackage": "github.com/gokrazy/kernel.arm64", "FirmwarePackage": "github.com/gokrazy/kernel.arm64", "EEPROMPackage": "", - "InternalCompatibilityFlags": {} + "InternalCompatibilityFlags": { + "InitImportPath": "github.com/tailscale/ts-gokrazy/gokrazyinit" + } } diff --git a/gokrazy/natlabapp.arm64/gokrazydeps.go b/gokrazy/natlabapp.arm64/gokrazydeps.go new file mode 100644 index 0000000000000..001ab89b840ea --- /dev/null +++ b/gokrazy/natlabapp.arm64/gokrazydeps.go @@ -0,0 +1,16 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build for_go_mod_tidy_only + +package gokrazydeps + +import ( + _ "github.com/gokrazy/gokrazy/cmd/dhcp" + _ "github.com/gokrazy/kernel.arm64" + _ "github.com/gokrazy/serial-busybox" + _ "github.com/tailscale/ts-gokrazy/gokrazyinit" + _ "tailscale.com/cmd/tailscale" + _ "tailscale.com/cmd/tailscaled" + _ "tailscale.com/cmd/tta" +) diff --git a/gokrazy/natlabapp/config.json b/gokrazy/natlabapp/config.json index 1968b2aac79f8..c46f018794934 100644 --- a/gokrazy/natlabapp/config.json +++ b/gokrazy/natlabapp/config.json @@ -27,5 +27,7 @@ "KernelPackage": "github.com/tailscale/gokrazy-kernel", "FirmwarePackage": "", "EEPROMPackage": "", - "InternalCompatibilityFlags": {} + "InternalCompatibilityFlags": { + "InitImportPath": "github.com/tailscale/ts-gokrazy/gokrazyinit" + } } diff --git a/gokrazy/natlabapp/gokrazydeps.go b/gokrazy/natlabapp/gokrazydeps.go index c5d2b32a3d543..2e4c1361c48ce 100644 --- a/gokrazy/natlabapp/gokrazydeps.go +++ b/gokrazy/natlabapp/gokrazydeps.go @@ -6,10 +6,10 @@ package gokrazydeps import ( - _ "github.com/gokrazy/gokrazy" _ "github.com/gokrazy/gokrazy/cmd/dhcp" _ "github.com/gokrazy/serial-busybox" _ "github.com/tailscale/gokrazy-kernel" + _ "github.com/tailscale/ts-gokrazy/gokrazyinit" _ "tailscale.com/cmd/tailscale" _ "tailscale.com/cmd/tailscaled" _ "tailscale.com/cmd/tta" diff --git a/gokrazy/tsapp/config.json b/gokrazy/tsapp/config.json index b88be53a456a8..15533afd1136a 100644 --- a/gokrazy/tsapp/config.json +++ b/gokrazy/tsapp/config.json @@ -33,5 +33,7 @@ ], "KernelPackage": "github.com/tailscale/gokrazy-kernel", "FirmwarePackage": "github.com/tailscale/gokrazy-kernel", - "InternalCompatibilityFlags": {} + "InternalCompatibilityFlags": { + "InitImportPath": "github.com/tailscale/ts-gokrazy/gokrazyinit" + } } diff --git a/gokrazy/tsapp/gokrazydeps.go b/gokrazy/tsapp/gokrazydeps.go index 931080647f8e5..22bdc3a499425 100644 --- a/gokrazy/tsapp/gokrazydeps.go +++ b/gokrazy/tsapp/gokrazydeps.go @@ -7,12 +7,12 @@ package gokrazydeps import ( _ "github.com/gokrazy/breakglass" - _ "github.com/gokrazy/gokrazy" _ "github.com/gokrazy/gokrazy/cmd/dhcp" _ "github.com/gokrazy/gokrazy/cmd/ntp" _ "github.com/gokrazy/gokrazy/cmd/randomd" _ "github.com/gokrazy/serial-busybox" _ "github.com/tailscale/gokrazy-kernel" + _ "github.com/tailscale/ts-gokrazy/gokrazyinit" _ "tailscale.com/cmd/tailscale" _ "tailscale.com/cmd/tailscaled" ) diff --git a/health/health.go b/health/health.go index 0cfe570c4296a..1829bd482ad6f 100644 --- a/health/health.go +++ b/health/health.go @@ -132,6 +132,11 @@ type Tracker struct { localLogConfigErr error tlsConnectionErrors map[string]error // map[ServerName]error metricHealthMessage any // nil or *metrics.MultiLabelMap[metricHealthMessageLabel] + + // IP forwarding check + // If non-nil, called periodically to check if IP forwarding is broken. + // Should return true if broken, false if healthy. + isIPForwardingBroken func() bool } // NewTracker contructs a new [Tracker] and attaches the given eventbus. @@ -1097,6 +1102,8 @@ func (t *Tracker) updateBuiltinWarnablesLocked() { t.setHealthyLocked(NetworkStatusWarnable) } + t.updateIPForwardingWarnableLocked() + if t.localLogConfigErr != nil { t.setUnhealthyLocked(localLogWarnable, Args{ ArgError: t.localLogConfigErr.Error(), @@ -1389,3 +1396,29 @@ func (t *Tracker) LastNoiseDialWasRecent() bool { t.lastNoiseDial = now return dur < 2*time.Minute } + +// SetIPForwardingCheck sets the function to check if IP forwarding is broken. +// The function should return true if IP forwarding is broken, false if healthy. +// Pass nil to disable IP forwarding checks. +func (t *Tracker) SetIPForwardingCheck(checkFunc func() bool) { + if t.nil() { + return + } + t.mu.Lock() + defer t.mu.Unlock() + + t.isIPForwardingBroken = checkFunc + + // Run an immediate check to set initial state + t.updateIPForwardingWarnableLocked() +} + +// updateIPForwardingWarnableLocked checks the IP forwarding state and +// sets or clears the ipForwardingWarnable accordingly. +func (t *Tracker) updateIPForwardingWarnableLocked() { + if t.isIPForwardingBroken != nil && t.isIPForwardingBroken() { + t.setUnhealthyLocked(ipForwardingWarnable, Args{}) + } else { + t.setHealthyLocked(ipForwardingWarnable) + } +} diff --git a/health/health_test.go b/health/health_test.go index 953c4dca26ea3..ccd49b19af360 100644 --- a/health/health_test.go +++ b/health/health_test.go @@ -82,8 +82,7 @@ func TestAppendWarnableDebugFlags(t *testing.T) { func TestNilMethodsDontCrash(t *testing.T) { var nilt *Tracker rv := reflect.ValueOf(nilt) - for i := 0; i < rv.NumMethod(); i++ { - mt := rv.Type().Method(i) + for mt, method := range rv.Methods() { t.Logf("calling Tracker.%s ...", mt.Name) var args []reflect.Value for j := 0; j < mt.Type.NumIn(); j++ { @@ -92,7 +91,7 @@ func TestNilMethodsDontCrash(t *testing.T) { } args = append(args, reflect.Zero(mt.Type.In(j))) } - rv.Method(i).Call(args) + method.Call(args) } } @@ -389,7 +388,7 @@ func TestShowUpdateWarnable(t *testing.T) { wantShow bool }{ { - desc: "nil ClientVersion", + desc: "nil-ClientVersion", check: true, cv: nil, wantWarnable: nil, @@ -403,35 +402,35 @@ func TestShowUpdateWarnable(t *testing.T) { wantShow: false, }, { - desc: "no LatestVersion", + desc: "no-LatestVersion", check: true, cv: &tailcfg.ClientVersion{RunningLatest: false, LatestVersion: ""}, wantWarnable: nil, wantShow: false, }, { - desc: "show regular update", + desc: "show-regular-update", check: true, cv: &tailcfg.ClientVersion{RunningLatest: false, LatestVersion: "1.2.3"}, wantWarnable: updateAvailableWarnable, wantShow: true, }, { - desc: "show security update", + desc: "show-security-update", check: true, cv: &tailcfg.ClientVersion{RunningLatest: false, LatestVersion: "1.2.3", UrgentSecurityUpdate: true}, wantWarnable: securityUpdateAvailableWarnable, wantShow: true, }, { - desc: "update check disabled", + desc: "update-check-disabled", check: false, cv: &tailcfg.ClientVersion{RunningLatest: false, LatestVersion: "1.2.3"}, wantWarnable: nil, wantShow: false, }, { - desc: "hide update with auto-updates", + desc: "hide-update-with-auto-updates", check: true, apply: opt.NewBool(true), cv: &tailcfg.ClientVersion{RunningLatest: false, LatestVersion: "1.2.3"}, @@ -439,7 +438,7 @@ func TestShowUpdateWarnable(t *testing.T) { wantShow: false, }, { - desc: "show security update with auto-updates", + desc: "show-security-update-with-auto-updates", check: true, apply: opt.NewBool(true), cv: &tailcfg.ClientVersion{RunningLatest: false, LatestVersion: "1.2.3", UrgentSecurityUpdate: true}, @@ -623,7 +622,7 @@ func TestControlHealth(t *testing.T) { } }) - t.Run("Strings()", func(t *testing.T) { + t.Run("Strings", func(t *testing.T) { wantStrs := []string{ "Control health message: Extra help.", "Control health message: Extra help. Learn more: http://www.example.com", @@ -1000,3 +999,86 @@ func TestCurrentStateETagWarnable(t *testing.T) { } }) } + +func TestIPForwardingState(t *testing.T) { + tests := []struct { + name string + checkFunc func() bool // nil means no check function + wantUnhealthy bool + }{ + { + name: "broken", + checkFunc: func() bool { return true }, + wantUnhealthy: true, + }, + { + name: "healthy", + checkFunc: func() bool { return false }, + wantUnhealthy: false, + }, + { + name: "no_check_function", + checkFunc: nil, + wantUnhealthy: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bus := eventbus.New() + tr := NewTracker(bus) + defer bus.Close() + + tr.SetIPNState("Running", true) + tr.SetIPForwardingCheck(tt.checkFunc) + + tr.mu.Lock() + tr.updateBuiltinWarnablesLocked() + tr.mu.Unlock() + + got := tr.IsUnhealthy(ipForwardingWarnable) + if got != tt.wantUnhealthy { + t.Errorf("IsUnhealthy(ipForwardingWarnable) = %v, want %v", got, tt.wantUnhealthy) + } + }) + } + + // Test state transitions + t.Run("transitions", func(t *testing.T) { + bus := eventbus.New() + tr := NewTracker(bus) + defer bus.Close() + + tr.SetIPNState("Running", true) + + // Start broken + tr.SetIPForwardingCheck(func() bool { return true }) + tr.mu.Lock() + tr.updateBuiltinWarnablesLocked() + tr.mu.Unlock() + + if !tr.IsUnhealthy(ipForwardingWarnable) { + t.Fatal("expected IP forwarding to be unhealthy initially") + } + + // Transition to healthy + tr.SetIPForwardingCheck(func() bool { return false }) + tr.mu.Lock() + tr.updateBuiltinWarnablesLocked() + tr.mu.Unlock() + + if tr.IsUnhealthy(ipForwardingWarnable) { + t.Fatal("expected IP forwarding to be healthy after transition") + } + + // Transition to nil (should stay healthy) + tr.SetIPForwardingCheck(nil) + tr.mu.Lock() + tr.updateBuiltinWarnablesLocked() + tr.mu.Unlock() + + if tr.IsUnhealthy(ipForwardingWarnable) { + t.Fatal("expected IP forwarding to be healthy after clearing check") + } + }) +} diff --git a/health/warnings.go b/health/warnings.go index fc9099af2ecc7..416cb8ab0cc70 100644 --- a/health/warnings.go +++ b/health/warnings.go @@ -298,3 +298,16 @@ var warmingUpWarnable = condRegister(func() *Warnable { Text: StaticMessage("Tailscale is starting. Please wait."), } }) + +// ipForwardingWarnable is a Warnable that warns the user that IP forwarding is disabled +// but subnet routing or exit node functionality is being used. +var ipForwardingWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: "ip-forwarding-off", + Title: "IP forwarding is off", + Severity: SeverityMedium, + MapDebugFlag: "warn-ip-forwarding-off", + Text: StaticMessage("Subnet routing is enabled, but IP forwarding is disabled. Check that IP forwarding is enabled on your machine."), + ImpactsConnectivity: true, + } +}) diff --git a/hostinfo/hostinfo.go b/hostinfo/hostinfo.go index f91f52ec0c3d8..11b0a25ccc238 100644 --- a/hostinfo/hostinfo.go +++ b/hostinfo/hostinfo.go @@ -23,7 +23,6 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/lazy" "tailscale.com/types/opt" - "tailscale.com/types/ptr" "tailscale.com/util/cloudenv" "tailscale.com/util/dnsname" "tailscale.com/util/lineiter" @@ -93,8 +92,8 @@ func condCall[T any](fn func() T) T { } var ( - lazyInContainer = &lazyAtomicValue[opt.Bool]{f: ptr.To(inContainer)} - lazyGoArchVar = &lazyAtomicValue[string]{f: ptr.To(goArchVar)} + lazyInContainer = &lazyAtomicValue[opt.Bool]{f: new(inContainer)} + lazyGoArchVar = &lazyAtomicValue[string]{f: new(goArchVar)} ) type lazyAtomicValue[T any] struct { diff --git a/hostinfo/hostinfo_darwin.go b/hostinfo/hostinfo_darwin.go index cd551ca425790..338ab9792c215 100644 --- a/hostinfo/hostinfo_darwin.go +++ b/hostinfo/hostinfo_darwin.go @@ -10,7 +10,6 @@ import ( "path/filepath" "golang.org/x/sys/unix" - "tailscale.com/types/ptr" ) func init() { @@ -19,7 +18,7 @@ func init() { } var ( - lazyOSVersion = &lazyAtomicValue[string]{f: ptr.To(osVersionDarwin)} + lazyOSVersion = &lazyAtomicValue[string]{f: new(osVersionDarwin)} ) func packageTypeDarwin() string { diff --git a/hostinfo/hostinfo_freebsd.go b/hostinfo/hostinfo_freebsd.go index 3a214ed2463cb..580d97a6d1027 100644 --- a/hostinfo/hostinfo_freebsd.go +++ b/hostinfo/hostinfo_freebsd.go @@ -11,7 +11,6 @@ import ( "os/exec" "golang.org/x/sys/unix" - "tailscale.com/types/ptr" "tailscale.com/version/distro" ) @@ -22,8 +21,8 @@ func init() { } var ( - lazyVersionMeta = &lazyAtomicValue[versionMeta]{f: ptr.To(freebsdVersionMeta)} - lazyOSVersion = &lazyAtomicValue[string]{f: ptr.To(osVersionFreeBSD)} + lazyVersionMeta = &lazyAtomicValue[versionMeta]{f: new(freebsdVersionMeta)} + lazyOSVersion = &lazyAtomicValue[string]{f: new(osVersionFreeBSD)} ) func distroNameFreeBSD() string { diff --git a/hostinfo/hostinfo_linux.go b/hostinfo/hostinfo_linux.go index bb9a5c58c1bb0..6b21d81529264 100644 --- a/hostinfo/hostinfo_linux.go +++ b/hostinfo/hostinfo_linux.go @@ -11,7 +11,6 @@ import ( "strings" "golang.org/x/sys/unix" - "tailscale.com/types/ptr" "tailscale.com/util/lineiter" "tailscale.com/version/distro" ) @@ -26,8 +25,8 @@ func init() { } var ( - lazyVersionMeta = &lazyAtomicValue[versionMeta]{f: ptr.To(linuxVersionMeta)} - lazyOSVersion = &lazyAtomicValue[string]{f: ptr.To(osVersionLinux)} + lazyVersionMeta = &lazyAtomicValue[versionMeta]{f: new(linuxVersionMeta)} + lazyOSVersion = &lazyAtomicValue[string]{f: new(osVersionLinux)} ) type versionMeta struct { @@ -69,7 +68,7 @@ func deviceModelLinux() string { } func getQnapQtsVersion(versionInfo string) string { - for _, field := range strings.Fields(versionInfo) { + for field := range strings.FieldsSeq(versionInfo) { if suffix, ok := strings.CutPrefix(field, "QTSFW_"); ok { return suffix } @@ -111,11 +110,11 @@ func linuxVersionMeta() (meta versionMeta) { if err != nil { break } - eq := bytes.IndexByte(line, '=') - if eq == -1 { + before, after, ok := bytes.Cut(line, []byte{'='}) + if !ok { continue } - k, v := string(line[:eq]), strings.Trim(string(line[eq+1:]), `"'`) + k, v := string(before), strings.Trim(string(after), `"'`) m[k] = v } diff --git a/hostinfo/hostinfo_uname.go b/hostinfo/hostinfo_uname.go index b358c0e2cb108..0185da49d8bc9 100644 --- a/hostinfo/hostinfo_uname.go +++ b/hostinfo/hostinfo_uname.go @@ -9,14 +9,13 @@ import ( "runtime" "golang.org/x/sys/unix" - "tailscale.com/types/ptr" ) func init() { unameMachine = lazyUnameMachine.Get } -var lazyUnameMachine = &lazyAtomicValue[string]{f: ptr.To(unameMachineUnix)} +var lazyUnameMachine = &lazyAtomicValue[string]{f: new(unameMachineUnix)} func unameMachineUnix() string { switch runtime.GOOS { diff --git a/hostinfo/hostinfo_windows.go b/hostinfo/hostinfo_windows.go index 5e0b340919e34..59b57433e0c65 100644 --- a/hostinfo/hostinfo_windows.go +++ b/hostinfo/hostinfo_windows.go @@ -11,7 +11,6 @@ import ( "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" - "tailscale.com/types/ptr" "tailscale.com/util/winutil" "tailscale.com/util/winutil/winenv" ) @@ -23,9 +22,9 @@ func init() { } var ( - lazyDistroName = &lazyAtomicValue[string]{f: ptr.To(distroNameWindows)} - lazyOSVersion = &lazyAtomicValue[string]{f: ptr.To(osVersionWindows)} - lazyPackageType = &lazyAtomicValue[string]{f: ptr.To(packageTypeWindows)} + lazyDistroName = &lazyAtomicValue[string]{f: new(distroNameWindows)} + lazyOSVersion = &lazyAtomicValue[string]{f: new(osVersionWindows)} + lazyPackageType = &lazyAtomicValue[string]{f: new(packageTypeWindows)} ) func distroNameWindows() string { diff --git a/ipn/auditlog/auditlog.go b/ipn/auditlog/auditlog.go index cc6b43cbdba08..0d6bd278d1996 100644 --- a/ipn/auditlog/auditlog.go +++ b/ipn/auditlog/auditlog.go @@ -69,8 +69,11 @@ type Opts struct { // IsRetryableError returns true if the given error is retryable // See [controlclient.apiResponseError]. Potentially retryable errors implement the Retryable() method. func IsRetryableError(err error) bool { - var retryable interface{ Retryable() bool } - return errors.As(err, &retryable) && retryable.Retryable() + retryable, ok := errors.AsType[interface { + error + Retryable() bool + }](err) + return ok && retryable.Retryable() } type backoffOpts struct { diff --git a/ipn/backend.go b/ipn/backend.go index 3183c8b5e7a4e..51617e08e575d 100644 --- a/ipn/backend.go +++ b/ipn/backend.go @@ -85,6 +85,13 @@ const ( NotifyHealthActions NotifyWatchOpt = 1 << 9 // if set, include PrimaryActions in health.State. Otherwise append the action URL to the text NotifyInitialSuggestedExitNode NotifyWatchOpt = 1 << 10 // if set, the first Notify message (sent immediately) will contain the current SuggestedExitNode if available + + NotifyInitialClientVersion NotifyWatchOpt = 1 << 11 // if set, the first Notify message (sent immediately) will contain the current ClientVersion if available and if update checks are enabled + + // NotifyPeerChanges, if set, causes netmap delta updates to be sent as [tailcfg.PeerChange] rather than a full NetMap. + // Full netmap responses from the control plane are still sent as a full NetMap. PeerChanges are only sent to sessions + // that have opted in to this mode. + NotifyPeerChanges NotifyWatchOpt = 1 << 12 ) // Notify is a communication from a backend (e.g. tailscaled) to a frontend @@ -110,8 +117,26 @@ type Notify struct { State *State // if non-nil, the new or current IPN state Prefs *PrefsView // if non-nil && Valid, the new or current preferences NetMap *netmap.NetworkMap // if non-nil, the new or current netmap - Engine *EngineStatus // if non-nil, the new or current wireguard stats - BrowseToURL *string // if non-nil, UI should open a browser right now + + // SelfChange, if non-nil, indicates that this node's own [tailcfg.Node] + // has changed: addresses, name, key expiry, capabilities, etc. It carries + // the new self node so reactive consumers (containerboot, kube agents, + // sniproxy, etc.) can read the current self state without watching the + // full netmap. + // + // Consumers that need additional state (peers, DNS config, packet + // filter) should react to SelfChange by fetching the relevant bits on + // demand via [LocalClient]. + SelfChange *tailcfg.Node `json:",omitzero"` + + // PeerChanges, if non-nil, is a list of [tailcfg.PeerChange] that have occurred since the last + // full netmap update. This is sent in lieu of a full NetMap when [NotifyPeerChanges] is set in + // the session's mask and a netmap update is derived from an incremental MapResponse. + // Full MapResponse updates from the control plane are sent as a full NetMap. + PeerChanges []*tailcfg.PeerChange `json:",omitzero"` + + Engine *EngineStatus // if non-nil, the new or current wireguard stats + BrowseToURL *string // if non-nil, UI should open a browser right now // FilesWaiting if non-nil means that files are buffered in // the Tailscale daemon and ready for local transfer to the @@ -182,6 +207,12 @@ func (n Notify) String() string { if n.NetMap != nil { sb.WriteString("NetMap{...} ") } + if n.SelfChange != nil { + fmt.Fprintf(&sb, "SelfChange(%v) ", n.SelfChange.StableID) + } + if n.PeerChanges != nil { + fmt.Fprintf(&sb, "PeerChanges(%d) ", len(n.PeerChanges)) + } if n.Engine != nil { fmt.Fprintf(&sb, "wg=%v ", *n.Engine) } diff --git a/ipn/conf.go b/ipn/conf.go index ef753a0b48544..de127a28a0d7b 100644 --- a/ipn/conf.go +++ b/ipn/conf.go @@ -4,6 +4,8 @@ package ipn import ( + "errors" + "fmt" "net/netip" "tailscale.com/tailcfg" @@ -101,12 +103,21 @@ func (c *ConfigVAlpha) ToPrefs() (MaskedPrefs, error) { mp.ExitNodeAllowLANAccessSet = true } if c.AdvertiseRoutes != nil { + var routeErrs []error + for _, route := range c.AdvertiseRoutes { + if route != route.Masked() { + routeErrs = append(routeErrs, fmt.Errorf("route %s has non-address bits set; expected %s", route, route.Masked())) + } + } + if err := errors.Join(routeErrs...); err != nil { + return mp, err + } mp.AdvertiseRoutes = c.AdvertiseRoutes mp.AdvertiseRoutesSet = true } if c.DisableSNAT != "" { mp.NoSNAT = c.DisableSNAT.EqualBool(true) - mp.NoSNAT = true + mp.NoSNATSet = true } if c.NoStatefulFiltering != "" { mp.NoStatefulFiltering = c.NoStatefulFiltering diff --git a/ipn/conf_test.go b/ipn/conf_test.go new file mode 100644 index 0000000000000..41b5c4506f827 --- /dev/null +++ b/ipn/conf_test.go @@ -0,0 +1,66 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package ipn + +import ( + "net/netip" + "testing" +) + +// TestConfigVAlpha_ToPrefs_AdvertiseRoutes tests that ToPrefs validates routes +// provided directly as netip.Prefix values (not parsed from JSON). +func TestConfigVAlpha_ToPrefs_AdvertiseRoutes(t *testing.T) { + tests := []struct { + name string + routes []netip.Prefix + wantErr bool + }{ + { + name: "valid_routes", + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("2001:db8::/32"), + }, + wantErr: false, + }, + { + name: "invalid_ipv4_route", + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.1/24"), + }, + wantErr: true, + }, + { + name: "invalid_ipv6_route", + routes: []netip.Prefix{ + netip.MustParsePrefix("2a01:4f9:c010:c015::1/64"), + }, + wantErr: true, + }, + { + name: "mixed_valid_and_invalid", + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.1/16"), + netip.MustParsePrefix("2001:db8::/32"), + netip.MustParsePrefix("2a01:4f9::1/64"), + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := ConfigVAlpha{ + Version: "alpha0", + AdvertiseRoutes: tt.routes, + } + + _, err := cfg.ToPrefs() + if (err != nil) != tt.wantErr { + t.Errorf("cfg.ToPrefs() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/ipn/desktop/sessions_windows.go b/ipn/desktop/sessions_windows.go index 6128548a51216..7d51900f9f730 100644 --- a/ipn/desktop/sessions_windows.go +++ b/ipn/desktop/sessions_windows.go @@ -510,10 +510,13 @@ func sessionWatcherWndProc(hWnd windows.HWND, msg uint32, wParam, lParam uintptr } func pumpThreadMessages() { - var msg _MSG - for getMessage(&msg, 0, 0, 0) != 0 { - translateMessage(&msg) - dispatchMessage(&msg) + var p runtime.Pinner + defer p.Unpin() + msg := &_MSG{} + p.Pin(msg) + for getMessage(msg, 0, 0, 0) != 0 { + translateMessage(msg) + dispatchMessage(msg) } } diff --git a/ipn/ipn_clone.go b/ipn/ipn_clone.go index 94aebefdfd73d..e179438cdcfcb 100644 --- a/ipn/ipn_clone.go +++ b/ipn/ipn_clone.go @@ -14,7 +14,6 @@ import ( "tailscale.com/types/opt" "tailscale.com/types/persist" "tailscale.com/types/preftype" - "tailscale.com/types/ptr" ) // Clone makes a deep copy of LoginProfile. @@ -25,6 +24,7 @@ func (src *LoginProfile) Clone() *LoginProfile { } dst := new(LoginProfile) *dst = *src + dst.UserProfile = *src.UserProfile.Clone() return dst } @@ -62,7 +62,7 @@ func (src *Prefs) Clone() *Prefs { } } if dst.RelayServerPort != nil { - dst.RelayServerPort = ptr.To(*src.RelayServerPort) + dst.RelayServerPort = new(*src.RelayServerPort) } dst.RelayServerStaticEndpoints = append(src.RelayServerStaticEndpoints[:0:0], src.RelayServerStaticEndpoints...) dst.Persist = src.Persist.Clone() @@ -122,7 +122,7 @@ func (src *ServeConfig) Clone() *ServeConfig { if v == nil { dst.TCP[k] = nil } else { - dst.TCP[k] = ptr.To(*v) + dst.TCP[k] = new(*v) } } } @@ -184,7 +184,7 @@ func (src *ServiceConfig) Clone() *ServiceConfig { if v == nil { dst.TCP[k] = nil } else { - dst.TCP[k] = ptr.To(*v) + dst.TCP[k] = new(*v) } } } diff --git a/ipn/ipn_view.go b/ipn/ipn_view.go index 90560cec0e195..4e9d46bda30a0 100644 --- a/ipn/ipn_view.go +++ b/ipn/ipn_view.go @@ -113,7 +113,7 @@ func (v LoginProfileView) Key() StateKey { return v.Đļ.Key } // UserProfile is the server provided UserProfile for this profile. // This is updated whenever the server provides a new UserProfile. -func (v LoginProfileView) UserProfile() tailcfg.UserProfile { return v.Đļ.UserProfile } +func (v LoginProfileView) UserProfile() tailcfg.UserProfileView { return v.Đļ.UserProfile.View() } // NodeID is the NodeID of the node that this profile is logged into. // This should be stable across tagging and untagging nodes. diff --git a/ipn/ipnext/ipnext.go b/ipn/ipnext/ipnext.go index 6dea49939af91..5ca50498a81ab 100644 --- a/ipn/ipnext/ipnext.go +++ b/ipn/ipnext/ipnext.go @@ -19,8 +19,10 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/tstime" + "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/mapx" + "tailscale.com/types/views" "tailscale.com/wgengine/filter" ) @@ -202,6 +204,16 @@ type Host interface { // NodeBackend returns the [NodeBackend] for the currently active node // (which is approximately the same as the current profile). NodeBackend() NodeBackend + + // AuthReconfigAsync asynchronously pushes a new configuration into wgengine, + // if engine updates are not currently blocked, based on the cached netmap and + // user prefs. The reconfiguration is applied to [ipnlocal.LocalBackend]'s currently + // active node at the time of execution. + // + // AuthReconfigAsync should not be called at a high rate (i.e., more often + // than prefs and netmap changes), except in experimental or proof-of-concept + // contexts, since reconfiguration is known to be slow. + AuthReconfigAsync() } // SafeBackend is a subset of the [ipnlocal.LocalBackend] type's methods that @@ -382,6 +394,42 @@ type Hooks struct { // Filter contains hooks for the packet filter. // See [filter.Filter] for details on how these hooks are invoked. Filter FilterHooks + + // ExtraWireGuardAllowedIPs is called with each peer's public key + // from the initial [wgcfg.Config], and returns a view of prefixes to + // append to each peer's AllowedIPs. + // + // The extra AllowedIPs are added after the [router.Config] is generated, but + // before the WireGuard config is sent to the engine, so the extra IPs are + // given to WireGuard, but not the OS routing table. + // + // The prefixes returned from the hook should not contain duplicates, either + // internally, or with netmap peer prefixes. Returned prefixes should only + // contain host routes, and not contain default or subnet routes. + // Subsequent calls that return an unchanged set of prefixes for a given peer, + // should return the prefixes in the same order for that peer, + // to prevent configuration churn. + // + // The returned slice should not be mutated by the extension after it is returned. + // + // The hook is called with LocalBackend's mutex locked. + // + // TODO(#17858): This hook may not be needed and can possibly be replaced by + // new hooks that fit into the new architecture that make use of new + // WireGuard APIs. + ExtraWireGuardAllowedIPs feature.Hook[func(key.NodePublic) views.Slice[netip.Prefix]] + + // ExtraRouterConfigRoutes returns a view of prefixes to append to [router.Config.Routes]. + // + // Routes goes through the WireGuard engine which makes efforts to avoid + // unnecessary reconfiguration by checking that things have actually changed. + // So implementors should make sure that the order of the prefixes is stable + // and that we don't have duplicate entries. + // + // The returned slice should not be mutated by the extension after it is returned. + // + // The hook is called with LocalBackend's mutex locked. + ExtraRouterConfigRoutes feature.Hook[func() views.Slice[netip.Prefix]] } // FilterHooks contains hooks that extensions can use to customize the packet diff --git a/ipn/ipnlocal/breaktcp_linux.go b/ipn/ipnlocal/breaktcp_linux.go index 0ba9ed6d78f19..1d7ea0f314b11 100644 --- a/ipn/ipnlocal/breaktcp_linux.go +++ b/ipn/ipnlocal/breaktcp_linux.go @@ -15,7 +15,7 @@ func init() { func breakTCPConnsLinux() error { var matched int - for fd := 0; fd < 1000; fd++ { + for fd := range 1000 { _, err := unix.GetsockoptTCPInfo(fd, unix.IPPROTO_TCP, unix.TCP_INFO) if err == nil { matched++ diff --git a/ipn/ipnlocal/bus.go b/ipn/ipnlocal/bus.go index 6061f7223988d..8be50801001b9 100644 --- a/ipn/ipnlocal/bus.go +++ b/ipn/ipnlocal/bus.go @@ -8,6 +8,7 @@ import ( "time" "tailscale.com/ipn" + "tailscale.com/tailcfg" "tailscale.com/tstime" ) @@ -116,8 +117,8 @@ func (s *rateLimitingBusSender) Run(ctx context.Context, ch <-chan *ipn.Notify) } } -// mergeBoringNotify merges new notify 'src' into possibly-nil 'dst', -// either mutating 'dst' or allocating a new one if 'dst' is nil, +// mergeBoringNotify merges new notify src into possibly-nil dst, +// either mutating dst or allocating a new one if dst is nil, // returning the merged result. // // dst and src must both be "boring" (i.e. not notable per isNotifiableNotify). @@ -127,6 +128,9 @@ func mergeBoringNotifies(dst, src *ipn.Notify) *ipn.Notify { } if src.NetMap != nil { dst.NetMap = src.NetMap + dst.PeerChanges = nil // full netmap supersedes any accumulated deltas + } else if src.PeerChanges != nil { + dst.PeerChanges = mergePeerChanges(dst.PeerChanges, src.PeerChanges) } if src.Engine != nil { dst.Engine = src.Engine @@ -134,6 +138,55 @@ func mergeBoringNotifies(dst, src *ipn.Notify) *ipn.Notify { return dst } +// mergePeerChanges merges new peer changes from src into dst, either +// mutating dst or allocating a new slice if dst is nil, returning the merged result. +// Values in src override those in dst for the same NodeID. +func mergePeerChanges(dst, src []*tailcfg.PeerChange) []*tailcfg.PeerChange { + idxByNode := make(map[tailcfg.NodeID]int, len(dst)) + for i, d := range dst { + idxByNode[d.NodeID] = i + } + + for _, nd := range src { + if oi, ok := idxByNode[nd.NodeID]; ok { + dst[oi] = mergePeerChangeForIpnBus(dst[oi], nd) + continue + } + idxByNode[nd.NodeID] = len(dst) + dst = append(dst, nd) + } + return dst +} + +// mergePeerChangeForIpnBus merges new with old, returning the result. +// Fields set in new override those in old; fields only set in old are preserved. +func mergePeerChangeForIpnBus(old, new *tailcfg.PeerChange) *tailcfg.PeerChange { + merged := *old + + // This is a subset of PeerChange that reflects only the fields that can + // be changed via a NodeMutation. If future fields can be updated via + // NodeMutations from map responses (and they are relevant to the ipn bus), then + // they should be added here and merged in the same way. + if new.DERPRegion != 0 { + // netmap.NodeMutationDerpHome + merged.DERPRegion = new.DERPRegion + } + if new.Online != nil { + // netmap.NodeMutationOnline + merged.Online = new.Online + } + if new.LastSeen != nil { + // netmap.NodeMutationLastSeen + merged.LastSeen = new.LastSeen + } + if new.Endpoints != nil { + // netmap.NodeMutationEndpoints + merged.Endpoints = new.Endpoints + } + + return &merged +} + // isNotableNotify reports whether n is a "notable" notification that // should be sent on the IPN bus immediately (e.g. to GUIs) without // rate limiting it for a few seconds. @@ -152,6 +205,7 @@ func isNotableNotify(n *ipn.Notify) bool { n.Prefs != nil || n.ErrMessage != nil || n.LoginFinished != nil || + n.SelfChange != nil || !n.DriveShares.IsNil() || n.Health != nil || len(n.IncomingFiles) > 0 || diff --git a/ipn/ipnlocal/bus_test.go b/ipn/ipnlocal/bus_test.go index 27ffebcdd570e..048e5bff4d6de 100644 --- a/ipn/ipnlocal/bus_test.go +++ b/ipn/ipnlocal/bus_test.go @@ -12,6 +12,7 @@ import ( "tailscale.com/drive" "tailscale.com/ipn" + "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstime" "tailscale.com/types/logger" @@ -29,24 +30,25 @@ func TestIsNotableNotify(t *testing.T) { {"empty", &ipn.Notify{}, false}, {"version", &ipn.Notify{Version: "foo"}, false}, {"netmap", &ipn.Notify{NetMap: new(netmap.NetworkMap)}, false}, + {"peerchanges", &ipn.Notify{PeerChanges: []*tailcfg.PeerChange{{}}}, false}, {"engine", &ipn.Notify{Engine: new(ipn.EngineStatus)}, false}, + {"selfchange", &ipn.Notify{SelfChange: &tailcfg.Node{}}, true}, } // Then for all other fields, assume they're notable. // We use reflect to catch fields that might be added in the future without // remembering to update the [isNotableNotify] function. rt := reflect.TypeFor[ipn.Notify]() - for i := range rt.NumField() { + for sf := range rt.Fields() { n := &ipn.Notify{} - sf := rt.Field(i) switch sf.Name { - case "_", "NetMap", "Engine", "Version": + case "_", "NetMap", "PeerChanges", "SelfChange", "Engine", "Version": // Already covered above or not applicable. continue case "DriveShares": n.DriveShares = views.SliceOfViews[*drive.Share, drive.ShareView](make([]*drive.Share, 1)) default: - rf := reflect.ValueOf(n).Elem().Field(i) + rf := reflect.ValueOf(n).Elem().FieldByIndex(sf.Index) switch rf.Kind() { case reflect.Pointer: rf.Set(reflect.New(rf.Type().Elem())) @@ -64,7 +66,7 @@ func TestIsNotableNotify(t *testing.T) { notify *ipn.Notify want bool }{ - name: "field-" + rt.Field(i).Name, + name: "field-" + sf.Name, notify: n, want: true, }) @@ -218,3 +220,103 @@ func TestRateLimitingBusSender(t *testing.T) { st.s.Run(ctx, incoming) }) } + +func TestMergePeerChanges(t *testing.T) { + online := true + offline := false + + t.Run("no_overlap_appends", func(t *testing.T) { + old := []*tailcfg.PeerChange{ + {NodeID: 1, DERPRegion: 1}, + } + new := []*tailcfg.PeerChange{ + {NodeID: 2, DERPRegion: 2}, + } + got := mergePeerChanges(old, new) + if len(got) != 2 { + t.Fatalf("len = %d; want 2", len(got)) + } + if got[0].NodeID != 1 || got[1].NodeID != 2 { + t.Errorf("got NodeIDs %d, %d; want 1, 2", got[0].NodeID, got[1].NodeID) + } + }) + + t.Run("overlap_merges", func(t *testing.T) { + old := []*tailcfg.PeerChange{ + {NodeID: 1, DERPRegion: 1, Online: &online}, + {NodeID: 2, DERPRegion: 10}, + } + new := []*tailcfg.PeerChange{ + {NodeID: 1, DERPRegion: 5, Online: &offline}, + } + got := mergePeerChanges(old, new) + if len(got) != 2 { + t.Fatalf("len = %d; want 2 (merged, not appended)", len(got)) + } + if got[0].DERPRegion != 5 { + t.Errorf("DERPRegion = %d; want 5 (from new)", got[0].DERPRegion) + } + if *got[0].Online != false { + t.Errorf("Online = %v; want false (from new)", *got[0].Online) + } + // Node 2 should be untouched. + if got[1].NodeID != 2 || got[1].DERPRegion != 10 { + t.Errorf("node 2 was modified unexpectedly") + } + }) + + t.Run("partial_overlap_merges_and_appends", func(t *testing.T) { + old := []*tailcfg.PeerChange{ + {NodeID: 1, DERPRegion: 1}, + } + new := []*tailcfg.PeerChange{ + {NodeID: 1, DERPRegion: 2}, + {NodeID: 3, DERPRegion: 30}, + } + got := mergePeerChanges(old, new) + if len(got) != 2 { + t.Fatalf("len = %d; want 2", len(got)) + } + if got[0].NodeID != 1 || got[0].DERPRegion != 2 { + t.Errorf("node 1: DERPRegion = %d; want 2", got[0].DERPRegion) + } + if got[1].NodeID != 3 || got[1].DERPRegion != 30 { + t.Errorf("node 3: DERPRegion = %d; want 30", got[1].DERPRegion) + } + }) + + t.Run("preserves_old_fields_on_merge", func(t *testing.T) { + old := []*tailcfg.PeerChange{ + {NodeID: 1, DERPRegion: 1, Online: &online, Cap: 10}, + } + new := []*tailcfg.PeerChange{ + {NodeID: 1, Online: &offline}, + } + got := mergePeerChanges(old, new) + if len(got) != 1 { + t.Fatalf("len = %d; want 1", len(got)) + } + if got[0].DERPRegion != 1 { + t.Errorf("DERPRegion = %d; want 1 (preserved from old)", got[0].DERPRegion) + } + if got[0].Cap != 10 { + t.Errorf("Cap = %d; want 10 (preserved from old)", got[0].Cap) + } + if *got[0].Online != false { + t.Errorf("Online = %v; want false (from new)", *got[0].Online) + } + }) + + t.Run("nil_old", func(t *testing.T) { + new := []*tailcfg.PeerChange{ + {NodeID: 1, DERPRegion: 1}, + } + got := mergePeerChanges(nil, new) + if len(got) != 1 { + t.Fatalf("len = %d; want 1", len(got)) + } + if got[0].NodeID != 1 { + t.Errorf("NodeID = %d; want 1", got[0].NodeID) + } + }) +} diff --git a/ipn/ipnlocal/c2n.go b/ipn/ipnlocal/c2n.go index ccce2a65d99e6..bf8cf2e038a64 100644 --- a/ipn/ipnlocal/c2n.go +++ b/ipn/ipnlocal/c2n.go @@ -27,6 +27,7 @@ import ( "tailscale.com/util/goroutines" "tailscale.com/util/httpm" "tailscale.com/util/set" + "tailscale.com/util/testenv" "tailscale.com/version" ) @@ -44,9 +45,6 @@ func init() { // several candidate nodes is reachable and actually alive. RegisterC2N("/echo", handleC2NEcho) } - if buildfeatures.HasSSH { - RegisterC2N("/ssh/usernames", handleC2NSSHUsernames) - } if buildfeatures.HasLogTail { RegisterC2N("POST /logtail/flush", handleC2NLogtailFlush) } @@ -290,26 +288,6 @@ func handleC2NPprof(b *LocalBackend, w http.ResponseWriter, r *http.Request) { c2nPprof(w, r, profile) } -func handleC2NSSHUsernames(b *LocalBackend, w http.ResponseWriter, r *http.Request) { - if !buildfeatures.HasSSH { - http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) - return - } - var req tailcfg.C2NSSHUsernamesRequest - if r.Method == "POST" { - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - } - res, err := b.getSSHUsernames(&req) - if err != nil { - http.Error(w, err.Error(), 500) - return - } - writeJSON(w, res) -} - func handleC2NSockStats(b *LocalBackend, w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") if b.sockstatLogger == nil { @@ -346,3 +324,10 @@ func handleC2NSetNetfilterKind(b *LocalBackend, w http.ResponseWriter, r *http.R w.WriteHeader(http.StatusNoContent) } + +// HandleC2NForTest calls [handleC2N], for use by feature/ packages that +// register C2N handlers and want to test them. +func (b *LocalBackend) HandleC2NForTest(w http.ResponseWriter, r *http.Request) { + testenv.AssertInTest() + b.handleC2N(w, r) +} diff --git a/ipn/ipnlocal/c2n_test.go b/ipn/ipnlocal/c2n_test.go index 810d6765b45e2..e5b15e79a2577 100644 --- a/ipn/ipnlocal/c2n_test.go +++ b/ipn/ipnlocal/c2n_test.go @@ -16,6 +16,7 @@ import ( "testing" "time" + "tailscale.com/health" "tailscale.com/ipn/store/mem" "tailscale.com/tailcfg" "tailscale.com/tstest" @@ -23,6 +24,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/views" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/must" gcmp "github.com/google/go-cmp/cmp" @@ -33,6 +35,7 @@ func TestHandleC2NTLSCertStatus(t *testing.T) { b := &LocalBackend{ store: &mem.Store{}, varRoot: t.TempDir(), + health: health.NewTracker(eventbustest.NewBus(t)), } certDir, err := b.certDir() if err != nil { @@ -63,7 +66,7 @@ func TestHandleC2NTLSCertStatus(t *testing.T) { want *tailcfg.C2NTLSCertInfo }{ { - name: "no domain", + name: "no-domain", wantStatus: 400, wantError: "no 'domain'\n", }, diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index efab9db7aad6e..eab70b295e5bd 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -909,7 +909,7 @@ func (b *LocalBackend) resolveCertDomain(domain string) (string, error) { } // Read the netmap once to get both CertDomains and capabilities atomically. - nm := b.NetMap() + nm := b.NetMapNoPeers() if nm == nil { return "", errors.New("no netmap available") } diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go index cc9146ae1e055..56d6df77f3128 100644 --- a/ipn/ipnlocal/cert_test.go +++ b/ipn/ipnlocal/cert_test.go @@ -39,25 +39,29 @@ func TestCertRequest(t *testing.T) { } tests := []struct { + name string domain string wantSANs []string }{ { + name: "example-com", domain: "example.com", wantSANs: []string{"example.com"}, }, { + name: "wildcard-example-com", domain: "*.example.com", wantSANs: []string{"*.example.com", "example.com"}, }, { + name: "wildcard-foo-bar-com", domain: "*.foo.bar.com", wantSANs: []string{"*.foo.bar.com", "foo.bar.com"}, }, } for _, tt := range tests { - t.Run(tt.domain, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { csrDER, err := certRequest(key, tt.domain, nil) if err != nil { t.Fatalf("certRequest: %v", err) @@ -365,19 +369,19 @@ func TestShouldStartDomainRenewal(t *testing.T) { wantErr string }{ { - name: "should renew", + name: "should-renew", notBefore: now.AddDate(0, 0, -89), lifetime: 90 * 24 * time.Hour, want: true, }, { - name: "short-lived renewal", + name: "short-lived-renewal", notBefore: now.AddDate(0, 0, -7), lifetime: 10 * 24 * time.Hour, want: true, }, { - name: "no renew", + name: "no-renew", notBefore: now.AddDate(0, 0, -59), // 59 days ago == not 2/3rds of the way through 90 days yet lifetime: 90 * 24 * time.Hour, want: false, @@ -515,8 +519,11 @@ func TestGetCertPEMWithValidity(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + tstest.AssertNotParallel(t) if tt.readOnlyMode { envknob.Setenv("TS_CERT_SHARE_MODE", "ro") + } else { + envknob.Setenv("TS_CERT_SHARE_MODE", "") } os.RemoveAll(certDir) diff --git a/ipn/ipnlocal/diskcache.go b/ipn/ipnlocal/diskcache.go index 0b1b7b4487bd1..3235869e6fab3 100644 --- a/ipn/ipnlocal/diskcache.go +++ b/ipn/ipnlocal/diskcache.go @@ -4,6 +4,10 @@ package ipnlocal import ( + "context" + "errors" + "fmt" + "tailscale.com/feature/buildfeatures" "tailscale.com/ipn/ipnlocal/netmapcache" "tailscale.com/types/netmap" @@ -31,7 +35,19 @@ func (b *LocalBackend) writeNetmapToDiskLocked(nm *netmap.NetworkMap) error { b.diskCache.cache = netmapcache.NewCache(netmapcache.FileStore(dir)) b.diskCache.dir = dir } - return b.diskCache.cache.Store(b.currentNode().Context(), nm) + + // Set the homeDERP on the self node before saving. The self node homeDERP is + // generally not used since the homeDERP for self is stored in magicsock, but + // to be able to load it during loading the cache, we use the existing field + // to save it. + + // Make a shallow copy and mutate a copy of the selfNode. + nmCopy := *nm + selfNode := nm.SelfNode.AsStruct() + selfNode.HomeDERP = int(b.currentNode().homeDERP.Load()) + nmCopy.SelfNode = selfNode.View() + + return b.diskCache.cache.Store(b.currentNode().Context(), &nmCopy) } func (b *LocalBackend) loadDiskCacheLocked() (om *netmap.NetworkMap, ok bool) { @@ -54,3 +70,60 @@ func (b *LocalBackend) loadDiskCacheLocked() (om *netmap.NetworkMap, ok bool) { } return nm, true } + +// discardDiskCacheLocked removes a cached network map for the current node, if +// one exists, and disables the cache. +func (b *LocalBackend) discardDiskCacheLocked() { + if !buildfeatures.HasCacheNetMap { + return + } + if b.diskCache.cache == nil { + return // nothing to do, we do not have a cache + } + // Reaching here, we have a cache directory that needs to be purged. + // Log errors but do not fail for them. + store := netmapcache.FileStore(b.diskCache.dir) + if err := b.clearStoreLocked(b.currentNode().Context(), store); err != nil { + b.logf("clearing netmap cache: %v", err) + } + b.diskCache = diskCache{} // drop in-memory state +} + +// clearStoreLocked discards all the keys in the specified store. +func (b *LocalBackend) clearStoreLocked(ctx context.Context, store netmapcache.Store) error { + var errs []error + for key, err := range store.List(ctx, "") { + if err != nil { + errs = append(errs, fmt.Errorf("list cache contest: %w", err)) + break + } + if err := store.Remove(ctx, key); err != nil { + errs = append(errs, fmt.Errorf("discard cache key %q: %w", key, err)) + } + } + return errors.Join(errs...) +} + +// ClearNetmapCache discards stored netmap caches (if any) for profiles for the +// current user of b. It also drops any cache from the active backend session, +// if there is one. +func (b *LocalBackend) ClearNetmapCache(ctx context.Context) error { + if !buildfeatures.HasCacheNetMap { + return nil // disabled + } + + b.mu.Lock() + defer b.mu.Unlock() + + var errs []error + for _, p := range b.pm.Profiles() { + store := netmapcache.FileStore(b.profileDataPathLocked(p.ID(), "netmap-cache")) + err := b.clearStoreLocked(ctx, store) + if err != nil { + errs = append(errs, fmt.Errorf("clear netmap cache for profile %q: %w", p.ID(), err)) + } + } + + b.diskCache = diskCache{} // drop in-memory state + return errors.Join(errs...) +} diff --git a/ipn/ipnlocal/diskcache_test.go b/ipn/ipnlocal/diskcache_test.go new file mode 100644 index 0000000000000..748ff6a408e64 --- /dev/null +++ b/ipn/ipnlocal/diskcache_test.go @@ -0,0 +1,229 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "net/netip" + "testing" + + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/types/netmap" + "tailscale.com/util/eventbus" + "tailscale.com/wgengine/magicsock" +) + +// newCacheTestNetmap returns a minimal valid netmap suitable for testing disk +// cache operations. +func newCacheTestNetmap() *netmap.NetworkMap { + return &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "test-node.ts.net", + User: tailcfg.UserID(1), + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + }, + }).View(), + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + tailcfg.UserID(1): (&tailcfg.UserProfile{ + LoginName: "user@example.com", + DisplayName: "Test User", + }).View(), + }, + DERPMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: {}, + 2: {}, + 3: {}, + 4: {}, + 5: {}, + 6: {}, + 7: {}, + 8: {}, + 9: {}, + 10: {}, + 11: {}, + }, + }, + } +} + +func TestWriteAndLoadHomeDERP(t *testing.T) { + b := newTestBackend(t) + + nm := newCacheTestNetmap() + b.currentNode().SetNetMap(nm) + + const wantDERP = 7 + b.currentNode().homeDERP.Store(wantDERP) + + b.mu.Lock() + defer b.mu.Unlock() + + if err := b.writeNetmapToDiskLocked(nm); err != nil { + t.Fatalf("writeNetmapToDiskLocked: %v", err) + } + + loaded, ok := b.loadDiskCacheLocked() + if !ok { + t.Fatal("loadDiskCacheLocked returned ok=false") + } + if !loaded.SelfNode.Valid() { + t.Fatal("loaded netmap SelfNode is invalid") + } + if got := loaded.SelfNode.HomeDERP(); got != wantDERP { + t.Errorf("loaded SelfNode.HomeDERP() = %d, want %d", got, wantDERP) + } +} + +func TestOnHomeDERPUpdate(t *testing.T) { + t.Run("normal_derp_change", func(t *testing.T) { + b := newTestBackend(t) + done := make(chan struct{}) + tstest.Replace(t, &testOnlyHomeDERPUpdate, func() { close(done) }) + + nm := newCacheTestNetmap() + b.currentNode().SetNetMap(nm) + + // Publish a HomeDERPChanged event via the backend's event bus. + bus := b.Sys().Bus.Get() + ec := bus.Client("test.TestOnHomeDERPUpdate") + pub := eventbus.Publish[magicsock.HomeDERPChanged](ec) + + const wantDERP = 11 + pub.Publish(magicsock.HomeDERPChanged{Old: 0, New: wantDERP}) + <-done + + if got := b.currentNode().homeDERP.Load(); got != wantDERP { + t.Errorf("b.homeDERP = %d, want %d", got, wantDERP) + } + + // Verify the value was persisted to the disk cache. + b.mu.Lock() + defer b.mu.Unlock() + loaded, ok := b.loadDiskCacheLocked() + if !ok { + t.Fatal("loadDiskCacheLocked returned ok=false after homeDERP update") + } + if got := loaded.SelfNode.HomeDERP(); got != wantDERP { + t.Errorf("cached SelfNode.HomeDERP() = %d, want %d", got, wantDERP) + } + }) + t.Run("old_does_not_match", func(t *testing.T) { + b := newTestBackend(t) + done := make(chan struct{}) + tstest.Replace(t, &testOnlyHomeDERPUpdate, func() { close(done) }) + + const setDERP = 11 + const wantDERP = 4 + + nm := newCacheTestNetmap() + selfNode := nm.SelfNode.AsStruct() + selfNode.HomeDERP = wantDERP + nm.SelfNode = selfNode.View() + b.currentNode().SetNetMap(nm) + b.currentNode().homeDERP.Store(wantDERP) + + // Write an initial cache entry so we can verify it is not overwritten. + b.mu.Lock() + if err := b.writeNetmapToDiskLocked(nm); err != nil { + b.mu.Unlock() + t.Fatalf("setup writeNetmapToDiskLocked: %v", err) + } + b.mu.Unlock() + + // Publish a HomeDERPChanged event via the backend's event bus. + bus := b.Sys().Bus.Get() + ec := bus.Client("test.TestOnHomeDERPUpdate") + pub := eventbus.Publish[magicsock.HomeDERPChanged](ec) + pub.Publish(magicsock.HomeDERPChanged{Old: wantDERP + 1, New: setDERP}) + <-done + + if got := b.currentNode().homeDERP.Load(); got != wantDERP { + t.Errorf("b.homeDERP = %d, wanted no change %d", got, wantDERP) + } + + // Verify the cache still exists and still holds the original value. + b.mu.Lock() + defer b.mu.Unlock() + loaded, ok := b.loadDiskCacheLocked() + if !ok { + t.Fatal("loadDiskCacheLocked returned ok=false; expected cache to still exist") + } + if got := loaded.SelfNode.HomeDERP(); got != wantDERP { + t.Errorf("cached SelfNode.HomeDERP() = %d after rejected event, want original %d", got, wantDERP) + } + }) + t.Run("new_does_not_exist_in_map", func(t *testing.T) { + b := newTestBackend(t) + done := make(chan struct{}) + tstest.Replace(t, &testOnlyHomeDERPUpdate, func() { close(done) }) + + const setDERP = 111 + const wantDERP = 4 + + nm := newCacheTestNetmap() + selfNode := nm.SelfNode.AsStruct() + selfNode.HomeDERP = wantDERP + nm.SelfNode = selfNode.View() + b.currentNode().SetNetMap(nm) + b.currentNode().homeDERP.Store(wantDERP) + + // Write an initial cache entry so we can verify it is not overwritten. + b.mu.Lock() + if err := b.writeNetmapToDiskLocked(nm); err != nil { + b.mu.Unlock() + t.Fatalf("setup writeNetmapToDiskLocked: %v", err) + } + b.mu.Unlock() + + // Publish a HomeDERPChanged event via the backend's event bus. + // Old matches the stored homeDERP so only the "new region not in map" + // guard is exercised. + bus := b.Sys().Bus.Get() + ec := bus.Client("test.TestOnHomeDERPUpdate") + pub := eventbus.Publish[magicsock.HomeDERPChanged](ec) + pub.Publish(magicsock.HomeDERPChanged{Old: wantDERP, New: setDERP}) + <-done + + if got := b.currentNode().homeDERP.Load(); got != wantDERP { + t.Errorf("b.homeDERP = %d, wanted no change %d", got, wantDERP) + } + + // Verify the cache still exists and still holds the original value. + b.mu.Lock() + defer b.mu.Unlock() + loaded, ok := b.loadDiskCacheLocked() + if !ok { + t.Fatal("loadDiskCacheLocked returned ok=false; expected cache to still exist") + } + if got := loaded.SelfNode.HomeDERP(); got != wantDERP { + t.Errorf("cached SelfNode.HomeDERP() = %d after rejected event, want original %d", got, wantDERP) + } + }) +} + +func TestWriteNetmapDoesNotMutateOriginal(t *testing.T) { + b := newTestBackend(t) + + nm := newCacheTestNetmap() + b.currentNode().SetNetMap(nm) + + originalDERP := nm.SelfNode.HomeDERP() // expected to be 0 initially + + const storeDERP = 5 + b.currentNode().homeDERP.Store(storeDERP) + + b.mu.Lock() + defer b.mu.Unlock() + + if err := b.writeNetmapToDiskLocked(nm); err != nil { + t.Fatalf("writeNetmapToDiskLocked: %v", err) + } + + // The original netmap must not have been mutated. + if got := nm.SelfNode.HomeDERP(); got != originalDERP { + t.Errorf("original nm.SelfNode.HomeDERP() = %d after write, want %d (original was mutated)", got, originalDERP) + } +} diff --git a/ipn/ipnlocal/drive.go b/ipn/ipnlocal/drive.go index 485114eae9d27..110ffff2a765d 100644 --- a/ipn/ipnlocal/drive.go +++ b/ipn/ipnlocal/drive.go @@ -303,18 +303,19 @@ func (b *LocalBackend) updateDrivePeersLocked(nm *netmap.NetworkMap) { } func (b *LocalBackend) driveRemotesFromPeers(nm *netmap.NetworkMap) []*drive.Remote { - b.logf("[v1] taildrive: setting up drive remotes from peers") + b.logf("[v1] taildrive: setting up drive remotes from %d peers", len(nm.Peers)) driveRemotes := make([]*drive.Remote, 0, len(nm.Peers)) for _, p := range nm.Peers { peer := p peerID := peer.ID() peerKey := peer.Key().ShortString() - b.logf("[v1] taildrive: appending remote for peer %s", peerKey) + peerName := peer.DisplayName(false) + driveRemotes = append(driveRemotes, &drive.Remote{ - Name: p.DisplayName(false), + Name: peerName, URL: func() string { url := fmt.Sprintf("%s/%s", b.currentNode().PeerAPIBase(peer), taildrivePrefix[1:]) - b.logf("[v2] taildrive: url for peer %s: %s", peerKey, url) + b.logf("[v2] taildrive: url for peer %s (%s): %s", peerKey, peerName, url) return url }, Available: func() bool { @@ -325,7 +326,7 @@ func (b *LocalBackend) driveRemotesFromPeers(nm *netmap.NetworkMap) []*drive.Rem cn := b.currentNode() peer, ok := cn.NodeByID(peerID) if !ok { - b.logf("[v2] taildrive: Available(): peer %s not found", peerKey) + b.logf("[v2] taildrive: peer %s (%s, id=%v) not found", peerKey, peerName, peerID) return false } @@ -338,26 +339,25 @@ func (b *LocalBackend) driveRemotesFromPeers(nm *netmap.NetworkMap) []*drive.Rem // The netmap.Peers slice is not updated in all cases. // It should be fixed now that we use PeerByIDOk. if !peer.Online().Get() { - b.logf("[v2] taildrive: Available(): peer %s offline", peerKey) + b.logf("[v2] taildrive: peer %s (%s, id=%v) offline", peerKey, peerName, peerID) return false } - - if b.currentNode().PeerAPIBase(peer) == "" { - b.logf("[v2] taildrive: Available(): peer %s PeerAPI unreachable", peerKey) + if cn.PeerAPIBase(peer) == "" { + b.logf("[v2] taildrive: peer %s (%s, id=%v) PeerAPI unreachable", peerKey, peerName, peerID) return false } - // Check that the peer is allowed to share with us. if cn.PeerHasCap(peer, tailcfg.PeerCapabilityTaildriveSharer) { - b.logf("[v2] taildrive: Available(): peer %s available", peerKey) + b.logf("[v2] taildrive: peer %s (%s, id=%v) available", peerKey, peerName, peerID) return true } - b.logf("[v2] taildrive: Available(): peer %s not allowed to share", peerKey) + b.logf("[v2] taildrive: peer %s (%s, id=%v) not allowed to share", peerKey, peerName, peerID) return false }, }) } + b.logf("[v1] taildrive: built %d candidate remotes", len(driveRemotes)) return driveRemotes } diff --git a/ipn/ipnlocal/extension_host.go b/ipn/ipnlocal/extension_host.go index 125a2329447a3..0c4b1d933f724 100644 --- a/ipn/ipnlocal/extension_host.go +++ b/ipn/ipnlocal/extension_host.go @@ -90,6 +90,11 @@ type ExtensionHost struct { extByType sync.Map // reflect.Type -> ipnext.Extension + // hasPendingAuthReconfig tracks whether an AuthReconfig call + // has been enqueued but not yet executed. It avoids redundant + // reconfigs and prevents them from piling up in the workQueue. + hasPendingAuthReconfig atomic.Bool + // mu protects the following fields. // It must not be held when calling [LocalBackend] methods // or when invoking callbacks registered by extensions. @@ -124,6 +129,8 @@ type Backend interface { NodeBackend() ipnext.NodeBackend + authReconfig() + ipnext.SafeBackend } @@ -339,7 +346,7 @@ func (h *ExtensionHost) FindMatchingExtension(target any) bool { val := reflect.ValueOf(target) typ := val.Type() - if typ.Kind() != reflect.Ptr || val.IsNil() { + if typ.Kind() != reflect.Pointer || val.IsNil() { panic("ipnext: target must be a non-nil pointer") } targetType := typ.Elem() @@ -541,6 +548,22 @@ func (h *ExtensionHost) Shutdown() { h.shutdownOnce.Do(h.shutdown) } +// AuthReconfigAsync implements [ipnext.Host.AuthReconfigAsync]. +// Since execution uses the most recent state at execution time, +// multiple enqueued calls are redundant and are not enqueued. +func (h *ExtensionHost) AuthReconfigAsync() { + if h == nil { + return + } + if h.hasPendingAuthReconfig.Swap(true) { + return + } + h.enqueueBackendOperation(func(b Backend) { + h.hasPendingAuthReconfig.Store(false) + b.authReconfig() + }) +} + func (h *ExtensionHost) shutdown() { h.shuttingDown.Store(true) // Prevent any queued but not yet started operations from running, diff --git a/ipn/ipnlocal/extension_host_test.go b/ipn/ipnlocal/extension_host_test.go index 3bd302aeab93d..58955dc6c7a67 100644 --- a/ipn/ipnlocal/extension_host_test.go +++ b/ipn/ipnlocal/extension_host_test.go @@ -1010,9 +1010,8 @@ func TestNilExtensionHostMethodCall(t *testing.T) { t.Parallel() var h *ExtensionHost - typ := reflect.TypeOf(h) - for i := range typ.NumMethod() { - m := typ.Method(i) + typ := reflect.TypeFor[*ExtensionHost]() + for m := range typ.Methods() { if strings.HasSuffix(m.Name, "ForTest") { // Skip methods that are only for testing. continue @@ -1376,6 +1375,7 @@ func (b *testBackend) Sys() *tsd.System { func (b *testBackend) SendNotify(ipn.Notify) { panic("not implemented") } func (b *testBackend) NodeBackend() ipnext.NodeBackend { panic("not implemented") } func (b *testBackend) TailscaleVarRoot() string { panic("not implemented") } +func (b *testBackend) authReconfig() { panic("not implemented") } func (b *testBackend) SwitchToBestProfile(reason string) { b.mu.Lock() diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index d2d52ca422b10..242b31b4bdbf3 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -81,7 +81,6 @@ import ( "tailscale.com/types/opt" "tailscale.com/types/persist" "tailscale.com/types/preftype" - "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/checkchange" "tailscale.com/util/clientmetric" @@ -153,6 +152,7 @@ type watchSession struct { owner ipnauth.Actor // or nil sessionID string cancel context.CancelFunc // to shut down the session + mask ipn.NotifyWatchOpt // watch options for this session } var ( @@ -411,6 +411,12 @@ type LocalBackend struct { // getCertForTest is used to retrieve TLS certificates in tests. // See [LocalBackend.ConfigureCertsForTest]. getCertForTest func(hostname string) (*TLSCertKeyPair, error) + + // existsPendingAuthReconfig tracks if a goroutine is waiting to + // acquire [LocalBackend]'s mutex inside of [LocalBackend.AuthReconfig]. + // It is used to prevent goroutines from piling up to do the same + // work of [LocalBackend.authReconfigLocked]. + existsPendingAuthReconfig atomic.Bool } // SetHardwareAttested enables hardware attestation key signatures in map @@ -534,6 +540,8 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo b.currentNodeAtomic.Store(nb) nb.ready() + e.SetPeerByIPPacketFunc(b.lookupPeerByIP) + if sys.InitialConfig != nil { if err := b.initPrefsFromConfig(sys.InitialConfig); err != nil { return nil, err @@ -577,7 +585,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo // Call our linkChange code once with the current state. // Following changes are triggered via the eventbus. - cd, err := netmon.NewChangeDelta(nil, b.interfaceState, false, false) + cd, err := netmon.NewChangeDelta(nil, b.interfaceState, 0, false) if err != nil { b.logf("[unexpected] setting initial netmon state failed: %v", err) } else { @@ -621,6 +629,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } eventbus.SubscribeFunc(ec, b.onAppConnectorRouteUpdate) eventbus.SubscribeFunc(ec, b.onAppConnectorStoreRoutes) + eventbus.SubscribeFunc(ec, b.onHomeDERPUpdate) mConn.SetNetInfoCallback(b.setNetInfo) // TODO(tailscale/tailscale#17887): move to eventbus return b, nil @@ -652,6 +661,53 @@ func (b *LocalBackend) onAppConnectorStoreRoutes(ri appctype.RouteInfo) { } } +// testOnlyHomeDERPUpdate if non-nil is called after setting home DERP and +// writing netmap to disk. +var testOnlyHomeDERPUpdate func() + +func (b *LocalBackend) onHomeDERPUpdate(du magicsock.HomeDERPChanged) { + b.mu.Lock() + defer b.mu.Unlock() + + b.onHomeDERPUpdateLocked(du) + + if testOnlyHomeDERPUpdate != nil { + testOnlyHomeDERPUpdate() + } +} + +// onHomeDERPUpdateLocked considitonally updates the homeDERP for use in the +// netmap cache. +// If we switched our currentNode by switching profiles, we might be trying +// to update the homeDERP from another profile. If the old homeDERP does not +// match what we expect, don't swap the homeDERP. +// In practice, it is possible that one profile with a homeDERP of 0 (no-derp) +// got switched before setting any home DERP or that DERP IDs match across +// DERP maps. Since the risk of this happening is small and the consequences +// of this is is just a possible less optimal DERP until the next reSTUN, +// accept this possibility. +func (b *LocalBackend) onHomeDERPUpdateLocked(du magicsock.HomeDERPChanged) { + cn := b.currentNode() + + if cn == nil || cn.DERPMap() == nil || cn.DERPMap().Regions == nil { + return + } + + if _, ok := cn.DERPMap().Regions[du.New]; !ok { + return + } + + if !cn.homeDERP.CompareAndSwap(int64(du.Old), int64(du.New)) { + return + } + + // Persist the full netmap (including up-to-date Peers) to disk for + // fast restart. + if err := b.writeNetmapToDiskLocked(b.NetMapWithPeers()); err != nil { + b.logf("write netmap to cache: %v", err) + } +} + func (b *LocalBackend) Clock() tstime.Clock { return b.clock } func (b *LocalBackend) Sys() *tsd.System { return b.sys } @@ -958,8 +1014,6 @@ func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error { return nil } -var assumeNetworkUpdateForTest = envknob.RegisterBool("TS_ASSUME_NETWORK_UP_FOR_TEST") - // pauseOrResumeControlClientLocked pauses b.cc if there is no network available // or if the LocalBackend is in Stopped state with a valid NetMap. In all other // cases, it unpauses it. It is a no-op if b.cc is nil. @@ -971,7 +1025,7 @@ func (b *LocalBackend) pauseOrResumeControlClientLocked() { return } networkUp := b.interfaceState.AnyInterfaceUp() - pauseForNetwork := (b.state == ipn.Stopped && b.NetMap() != nil) || (!networkUp && !testenv.InTest() && !assumeNetworkUpdateForTest()) + pauseForNetwork := (b.state == ipn.Stopped && b.NetMapNoPeers() != nil) || (!networkUp && !testenv.InTest() && !envknob.AssumeNetworkUp()) prefs := b.pm.CurrentPrefs() pauseForSyncPref := prefs.Valid() && prefs.Sync().EqualBool(false) @@ -1073,6 +1127,14 @@ func (b *LocalBackend) onHealthChange(change health.Change) { Health: state, }) + // Update control if IP forwarding state changed + _, broken := state.Warnings["ip-forwarding-off"] + b.mu.Lock() + if b.cc != nil { + b.cc.SetIPForwardingBroken(broken) + } + b.mu.Unlock() + if f, ok := hookCaptivePortalHealthChange.GetOk(); ok { f(b, state) } @@ -1466,6 +1528,7 @@ func profileFromView(v tailcfg.UserProfileView) tailcfg.UserProfile { LoginName: v.LoginName(), DisplayName: v.DisplayName(), ProfilePicURL: v.ProfilePicURL(), + Groups: v.Groups().AsSlice(), } } return tailcfg.UserProfile{} @@ -1558,6 +1621,16 @@ func (b *LocalBackend) PeerCaps(src netip.Addr) tailcfg.PeerCapMap { return b.currentNode().PeerCaps(src) } +// PeerByID returns the current full [tailcfg.Node] for the peer with the +// given NodeID, in O(1) time. It returns ok=false if no such peer is in +// the current netmap. +// +// It is intended for callers that need the latest state of a single peer +// without fetching the entire netmap. +func (b *LocalBackend) PeerByID(id tailcfg.NodeID) (n tailcfg.NodeView, ok bool) { + return b.currentNode().NodeByID(id) +} + func (b *LocalBackend) GetFilterForTest() *filter.Filter { testenv.AssertInTest() nb := b.currentNode() @@ -1738,7 +1811,7 @@ func (b *LocalBackend) setControlClientStatusLocked(c controlclient.Client, st c b.logf("Failed to save new controlclient state: %v", err) } - b.sendToLocked(ipn.Notify{Prefs: ptr.To(prefs.View())}, allClients) + b.sendToLocked(ipn.Notify{Prefs: new(prefs.View())}, allClients) } // initTKALocked is dependent on CurrentProfile.ID, which is initialized @@ -1808,7 +1881,24 @@ func (b *LocalBackend) setControlClientStatusLocked(c controlclient.Client, st c } b.e.SetNetworkMap(st.NetMap) - b.MagicConn().SetDERPMap(st.NetMap.DERPMap) + + var cachedHome int + if c == nil && st.NetMap.Cached && st.NetMap.SelfNode.Valid() { + cachedHome = st.NetMap.SelfNode.HomeDERP() + } + if cachedHome != 0 { + // Loading from a cached netmap (c == nil means no live control + // client). Pre-seed the home DERP from the cached self node so + // that the guard in maybeSetNearestDERP prevents changing the + // DERP home before we reconnect to the control plane. If the cache has + // nothing in it, skip this, and let the node pick a DERP itself. + b.MagicConn().SetDERPMapWithoutReSTUN(st.NetMap.DERPMap) + b.health.SetOutOfPollNetMap() + b.MagicConn().ForceSetNearestDERP(cachedHome) + } else { + b.MagicConn().SetDERPMap(st.NetMap.DERPMap) + } + b.MagicConn().SetOnlyTCP443(st.NetMap.HasCap(tailcfg.NodeAttrOnlyTCP443)) // Update our cached DERP map @@ -1817,7 +1907,15 @@ func (b *LocalBackend) setControlClientStatusLocked(c controlclient.Client, st c // Update the DERP map in the health package, which uses it for health notifications b.health.SetDERPMap(st.NetMap.DERPMap) - b.sendLocked(ipn.Notify{NetMap: st.NetMap}) + // Notify watchers that the self node may have changed. Reactive + // consumers (containerboot, kube agents, sniproxy, etc.) listen on + // this signal and re-fetch peers/DNS via the LocalAPI if they need + // more than self info. + var selfChange *tailcfg.Node + if st.NetMap.SelfNode.Valid() { + selfChange = st.NetMap.SelfNode.AsStruct() + } + b.sendLocked(ipn.Notify{NetMap: st.NetMap, SelfChange: selfChange}) // The error here is unimportant as is the result. This will recalculate the suggested exit node // cache the value and push any changes to the IPN bus. @@ -1843,6 +1941,18 @@ func (b *LocalBackend) setControlClientStatusLocked(c controlclient.Client, st c b.authReconfigLocked() } +func (b *LocalBackend) PatchDiscoKey(pub key.NodePublic, disco key.DiscoPublic) { + // PatchDiscoKey mirrors the implementation of [controlclient.patchDiscoKeyer]. + // It is implemented here to avoid the dependency edge to controlclient, but must be kept + // in sync with the original implementation. + type patchDiscoKeyer interface { + PatchDiscoKey(key.NodePublic, key.DiscoPublic) + } + if e, ok := b.e.(patchDiscoKeyer); ok { + e.PatchDiscoKey(pub, disco) + } +} + type preferencePolicyInfo struct { key pkey.Key get func(ipn.PrefsView) bool @@ -2140,10 +2250,20 @@ func (b *LocalBackend) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bo b.suggestExitNodeLocked() } - if cn.NetMap() != nil && mutationsAreWorthyOfTellingIPNBus(muts) { + if cn.NetMap() == nil { + b.logf("[unexpected] got node mutations but netmap is nil; mutations not applied") + return true + } - nm := cn.netMapWithPeers() - notify = &ipn.Notify{NetMap: nm} + if mutationsAreWorthyOfTellingIPNBus(muts) { + // The notifier will strip the netmap based on the watchOpts mask if the watcher + // has indicated it can handle PeerChanges. + notify = &ipn.Notify{NetMap: cn.netMapWithPeers()} + if peerChanges, ok := ipnBusPeerChangesFromNodeMutations(muts); ok { + notify.PeerChanges = peerChanges + } else { + b.logf("[unexpected] got mutations worthy of telling IPN bus but failed to convert to peer changes") + } } else if testenv.InTest() { // In tests, send an empty Notify as a wake-up so end-to-end // integration tests in another repo can check on the status of @@ -2191,6 +2311,39 @@ func mutationsAreWorthyOfRecalculatingSuggestedExitNode(muts []netmap.NodeMutati return false } +// ipnBusPeerChangesFromNodeMutations converts a slice of NodeMutations to a slice of +// *tailcfg.PeerChange for use in ipn.Notify.PeerChanges. +// Multiple mutations to the same node are merged into a single PeerChange. +// If we encounter any mutations that we cannot convert to a PeerChange, we return (nil, false) +// to indicate that the caller should send a Notify with the full netmap instead of +// trying to send granular peer changes. +func ipnBusPeerChangesFromNodeMutations(muts []netmap.NodeMutation) ([]*tailcfg.PeerChange, bool) { + byID := map[tailcfg.NodeID]*tailcfg.PeerChange{} + var ordered []*tailcfg.PeerChange + for _, m := range muts { + nid := m.NodeIDBeingMutated() + pc := byID[nid] + if pc == nil { + pc = &tailcfg.PeerChange{NodeID: nid} + byID[nid] = pc + ordered = append(ordered, pc) + } + switch v := m.(type) { + case netmap.NodeMutationOnline: + pc.Online = &v.Online + case netmap.NodeMutationLastSeen: + pc.LastSeen = &v.LastSeen + case netmap.NodeMutationDERPHome: + pc.DERPRegion = v.DERPRegion + case netmap.NodeMutationEndpoints: + pc.Endpoints = v.Endpoints + default: + return nil, false + } + } + return ordered, true +} + // mutationsAreWorthyOfTellingIPNBus reports whether any mutation type in muts is // worthy of spamming the IPN bus (the Windows & Mac GUIs, basically) to tell them // about the update. @@ -2403,6 +2556,14 @@ func (b *LocalBackend) PeersForTest() []tailcfg.NodeView { return b.currentNode().PeersForTest() } +// AwaitNodeKeyForTest returns a channel that is closed once a peer with the +// given node key first appears in the current netmap. If the peer is already +// present, the returned channel is already closed. See +// [nodeBackend.AwaitNodeKeyForTest]. +func (b *LocalBackend) AwaitNodeKeyForTest(k key.NodePublic) <-chan struct{} { + return b.currentNode().AwaitNodeKeyForTest(k) +} + func (b *LocalBackend) getNewControlClientFuncLocked() clientGen { if b.ccGen == nil { // Initialize it rather than just returning the @@ -2562,6 +2723,7 @@ func (b *LocalBackend) startLocked(opts ipn.Options) error { // Reset the always-on override whenever Start is called. b.resetAlwaysOnOverrideLocked() b.setAtomicValuesFromPrefsLocked(prefs) + b.updateNoSNATExitNodeWarning(prefs) wantRunning := prefs.WantRunning() if wantRunning { @@ -2584,7 +2746,21 @@ func (b *LocalBackend) startLocked(opts ipn.Options) error { persistv = new(persist.Persist) } - if envknob.Bool("TS_USE_CACHED_NETMAP") { + // At this point we do not yet know whether we are meant to cache netmaps by + // policy (as we have not yet spoken to the control plane). + // + // However, since we do not create or update a netmap cache unless we observe the + // [tailcfg.NodeAttrCachedNetworkMaps] capability, we can use the presence + // of the cached netmap as a signal that we were expected to do so as of the + // last time we updated the cache. + // + // If the policy has (since) changed, a subsequent network map from the control + // plane may remove the attribute, at which point we will drop the cache. + // + // As of 2026-03-25 we require the envknob set to read a cached netmap, with + // the envknob defaulted to true so we can use it as a safety override + // during rollout. + if envknob.BoolDefaultTrue("TS_USE_CACHED_NETMAP") { if nm, ok := b.loadDiskCacheLocked(); ok { logf("loaded netmap from disk cache; %d peers", len(nm.Peers)) b.setControlClientStatusLocked(nil, controlclient.Status{ @@ -2623,6 +2799,7 @@ func (b *LocalBackend) startLocked(opts ipn.Options) error { DiscoPublicKey: discoPublic, DebugFlags: b.controlDebugFlags(), HealthTracker: b.health, + ExtraRootCAs: b.sys.ExtraRootCAs, PolicyClient: b.sys.PolicyClientOrDefault(), Pinger: b, PopBrowserURL: b.tellClientToBrowseToURL, @@ -2634,10 +2811,6 @@ func (b *LocalBackend) startLocked(opts ipn.Options) error { Shutdown: ccShutdown, Bus: b.sys.Bus.Get(), StartPaused: prefs.Sync().EqualBool(false), - - // Don't warn about broken Linux IP forwarding when - // netstack is being used. - SkipIPForwardingCheck: b.sys.IsNetstackRouter(), }) if err != nil { return err @@ -2687,7 +2860,7 @@ func (b *LocalBackend) startLocked(opts ipn.Options) error { // Without this, the state machine transitions to "NeedsLogin" implying // that user interaction is required, which is not the case and can // regress tsnet.Server restarts. - cc.Login(controlclient.LoginDefault) + cc.Login(b.loginFlags) } b.stateMachineLocked() @@ -3138,19 +3311,19 @@ func (b *LocalBackend) WatchNotificationsAs(ctx context.Context, actor ipnauth.A b.mu.Lock() - const initialBits = ipn.NotifyInitialState | ipn.NotifyInitialPrefs | ipn.NotifyInitialNetMap | ipn.NotifyInitialDriveShares | ipn.NotifyInitialSuggestedExitNode + const initialBits = ipn.NotifyInitialState | ipn.NotifyInitialPrefs | ipn.NotifyInitialNetMap | ipn.NotifyInitialDriveShares | ipn.NotifyInitialSuggestedExitNode | ipn.NotifyInitialClientVersion if mask&initialBits != 0 { cn := b.currentNode() ini = &ipn.Notify{Version: version.Long()} if mask&ipn.NotifyInitialState != 0 { ini.SessionID = sessionID - ini.State = ptr.To(b.state) + ini.State = new(b.state) if b.state == ipn.NeedsLogin && b.authURL != "" { - ini.BrowseToURL = ptr.To(b.authURL) + ini.BrowseToURL = new(b.authURL) } } if mask&ipn.NotifyInitialPrefs != 0 { - ini.Prefs = ptr.To(b.sanitizedPrefsLocked()) + ini.Prefs = new(b.sanitizedPrefsLocked()) } if mask&ipn.NotifyInitialNetMap != 0 { ini.NetMap = cn.NetMap() @@ -3166,6 +3339,11 @@ func (b *LocalBackend) WatchNotificationsAs(ctx context.Context, actor ipnauth.A ini.SuggestedExitNode = &en.ID } } + if mask&ipn.NotifyInitialClientVersion != 0 { + if prefs := b.pm.CurrentPrefs(); prefs.Valid() && prefs.AutoUpdate().Check { + ini.ClientVersion = b.lastClientVersion + } + } } ctx, cancel := context.WithCancel(ctx) @@ -3176,6 +3354,7 @@ func (b *LocalBackend) WatchNotificationsAs(ctx context.Context, actor ipnauth.A owner: actor, sessionID: sessionID, cancel: cancel, + mask: mask, } mak.Set(&b.notifyWatchers, sessionID, session) b.mu.Unlock() @@ -3402,7 +3581,7 @@ func (b *LocalBackend) sendTo(n ipn.Notify, recipient notificationTarget) { // sendToLocked is like [LocalBackend.sendTo], but assumes b.mu is already held. func (b *LocalBackend) sendToLocked(n ipn.Notify, recipient notificationTarget) { if n.Prefs != nil { - n.Prefs = ptr.To(stripKeysFromPrefs(*n.Prefs)) + n.Prefs = new(stripKeysFromPrefs(*n.Prefs)) } if n.Version == "" { n.Version = version.Long() @@ -3413,13 +3592,27 @@ func (b *LocalBackend) sendToLocked(n ipn.Notify, recipient notificationTarget) } for _, sess := range b.notifyWatchers { - if recipient.match(sess.owner) { - select { - case sess.ch <- &n: - default: - // Drop the notification if the channel is full. + if !recipient.match(sess.owner) { + continue + } + nOut := &n + if n.PeerChanges != nil { + // Take a shallow copy of n so we can elide the PeerChanges or the Netmap + // based on the session's mask. + nOut = new(n) + if sess.mask&ipn.NotifyPeerChanges != 0 { + // Skip the full Netmap + nOut.NetMap = nil + } else { + // Skip the PeerChanges + nOut.PeerChanges = nil } } + select { + case sess.ch <- nOut: + default: + // Drop the notification if the channel is full. + } } } @@ -3468,12 +3661,11 @@ func (b *LocalBackend) setAuthURLLocked(url string) { // // b.mu must be held. func (b *LocalBackend) popBrowserAuthNowLocked(url string, keyExpired bool, recipient ipnauth.Actor) { - b.logf("popBrowserAuthNow(%q): url=%v, key-expired=%v, seamless-key-renewal=%v", maybeUsernameOf(recipient), url != "", keyExpired, b.seamlessRenewalEnabled()) + b.logf("popBrowserAuthNow(%q): url=%v, key-expired=%v", maybeUsernameOf(recipient), url != "", keyExpired) - // Deconfigure the local network data plane if: - // - seamless key renewal is not enabled; - // - key is expired (in which case tailnet connectivity is down anyway). - if !b.seamlessRenewalEnabled() || keyExpired { + // Deconfigure the local network data plane if the key is expired + // (in which case tailnet connectivity is down anyway). + if keyExpired { b.blockEngineUpdatesLocked(true) b.stopEngineAndWaitLocked() @@ -3546,10 +3738,13 @@ func (b *LocalBackend) tellRecipientToBrowseToURLLocked(url string, recipient no // a non-nil ClientVersion message. func (b *LocalBackend) onClientVersion(v *tailcfg.ClientVersion) { b.mu.Lock() + defer b.mu.Unlock() b.lastClientVersion = v b.health.SetLatestVersion(v) - b.mu.Unlock() - b.send(ipn.Notify{ClientVersion: v}) + prefs := b.pm.CurrentPrefs() + if prefs.Valid() && prefs.AutoUpdate().Check { + b.sendLocked(ipn.Notify{ClientVersion: v}) + } } func (b *LocalBackend) onTailnetDefaultAutoUpdate(au bool) { @@ -3693,12 +3888,7 @@ func generateInterceptTCPPortFunc(ports []uint16) func(uint16) bool { f = func(p uint16) bool { return m[p] } } else { f = func(p uint16) bool { - for _, x := range ports { - if p == x { - return true - } - } - return false + return slices.Contains(ports, p) } } } @@ -3892,7 +4082,8 @@ func (b *LocalBackend) pingPeerAPI(ctx context.Context, ip netip.Addr) (peer tai var zero tailcfg.NodeView ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - nm := b.NetMap() + // PeerByTailscaleIP needs an up-to-date Peers slice. + nm := b.NetMapWithPeers() if nm == nil { return zero, "", errors.New("no netmap") } @@ -4119,6 +4310,8 @@ func (b *LocalBackend) CurrentUserForTest() (ipn.WindowsUserID, ipnauth.Actor) { return b.pm.CurrentUserID(), b.currentUser } +// CheckPrefs validates the provided user modifiable settings for correctness +// and returns an error if they are invalid for the current backend. func (b *LocalBackend) CheckPrefs(p *ipn.Prefs) error { b.mu.Lock() defer b.mu.Unlock() @@ -4158,6 +4351,9 @@ func (b *LocalBackend) checkPrefsLocked(p *ipn.Prefs) error { if err := b.checkAutoUpdatePrefsLocked(p); err != nil { errs = append(errs, err) } + if err := checkAdvertiseRoutes(p); err != nil { + errs = append(errs, err) + } return errors.Join(errs...) } @@ -4255,6 +4451,18 @@ func (b *LocalBackend) checkAutoUpdatePrefsLocked(p *ipn.Prefs) error { return nil } +// checkAdvertiseRoutes validates that all advertised routes have +// properly masked prefixes (no non-address bits set). +func checkAdvertiseRoutes(p *ipn.Prefs) error { + var errs []error + for _, route := range p.AdvertiseRoutes { + if route != route.Masked() { + errs = append(errs, fmt.Errorf("route %s has non-address bits set; expected %s", route, route.Masked())) + } + } + return errors.Join(errs...) +} + // SetUseExitNodeEnabled turns on or off the most recently selected exit node. // // On success, it returns the resulting prefs (or current prefs, in the case of no change). @@ -4420,7 +4628,7 @@ func (b *LocalBackend) changeDisablesExitNodeLocked(prefs ipn.PrefsView, change // First, apply the adjustments to a copy of the changes, // e.g., clear AutoExitNode if ExitNodeID is set. - tmpChange := ptr.To(*change) + tmpChange := new(*change) tmpChange.Prefs = *change.Prefs.Clone() b.adjustEditPrefsLocked(prefs, tmpChange) @@ -4700,7 +4908,7 @@ func (b *LocalBackend) setPrefsLocked(newp *ipn.Prefs) ipn.PrefsView { if !oldp.Persist().Valid() { b.logf("active login: %s", newLoginName) } else { - oldLoginName := oldp.Persist().UserProfile().LoginName + oldLoginName := oldp.Persist().UserProfile().LoginName() if oldLoginName != newLoginName { b.logf("active login: %q (changed from %q)", newLoginName, oldLoginName) } @@ -4722,6 +4930,7 @@ func (b *LocalBackend) setPrefsLocked(newp *ipn.Prefs) ipn.PrefsView { b.pauseOrResumeControlClientLocked() // for prefs.Sync changes b.updateWarnSync(prefs) + b.updateNoSNATExitNodeWarning(prefs) if oldp.ShieldsUp() != newp.ShieldsUp || hostInfoChanged { b.doSetHostinfoFilterServicesLocked() @@ -4733,7 +4942,7 @@ func (b *LocalBackend) setPrefsLocked(newp *ipn.Prefs) ipn.PrefsView { if !oldp.WantRunning() && newp.WantRunning && cc != nil { b.logf("transitioning to running; doing Login...") - cc.Login(controlclient.LoginDefault) + cc.Login(b.loginFlags) } if oldp.WantRunning() != newp.WantRunning { @@ -4742,6 +4951,12 @@ func (b *LocalBackend) setPrefsLocked(newp *ipn.Prefs) ipn.PrefsView { b.authReconfigLocked() } + if newp.AutoUpdate.Check && !oldp.AutoUpdate().Check { + if cv := b.lastClientVersion; cv != nil { + b.sendLocked(ipn.Notify{ClientVersion: cv}) + } + } + b.sendLocked(ipn.Notify{Prefs: &prefs}) return prefs } @@ -4778,7 +4993,7 @@ func (b *LocalBackend) handlePeerAPIConn(remote, local netip.AddrPort, c net.Con } func (b *LocalBackend) isLocalIP(ip netip.Addr) bool { - nm := b.NetMap() + nm := b.NetMapNoPeers() return nm != nil && views.SliceContains(nm.GetAddresses(), netip.PrefixFrom(ip, ip.BitLen())) } @@ -4930,10 +5145,67 @@ func extractPeerAPIPorts(services []tailcfg.Service) portPair { // NetMap returns the latest cached network map received from // controlclient, or nil if no network map was received yet. +// +// Deprecated: callers should declare their needs explicitly by calling +// either [LocalBackend.NetMapNoPeers] (cheap; for code that reads +// non-Peers fields like SelfNode, DNS, PacketFilter, capabilities) or +// [LocalBackend.NetMapWithPeers] (currently the same; will be made to +// return an up-to-date Peers slice in a follow-up change, at the cost of +// O(N) work per call). NetMap will eventually be removed. func (b *LocalBackend) NetMap() *netmap.NetworkMap { return b.currentNode().NetMap() } +// NetMapNoPeers returns the latest cached network map received from +// controlclient WITHOUT a freshly-built Peers slice. +// +// On a tailnet with frequent peer churn the cached netmap's Peers slice +// can be stale relative to the live per-node-backend peers map; non-Peers +// fields (SelfNode, DNS, PacketFilter, capabilities, ...) are always +// current. Use this for any caller that does not need to iterate Peers, +// since it's O(1) regardless of tailnet size. +// +// Returns nil if no network map has been received yet. +func (b *LocalBackend) NetMapNoPeers() *netmap.NetworkMap { + return b.currentNode().NetMap() +} + +// NetMapWithPeers returns the latest network map with the Peers slice +// populated. +// +// Currently this is the same as [LocalBackend.NetMapNoPeers]: the cached +// netmap's Peers slice may be stale relative to the live per-node-backend +// peers map. A follow-up change will switch this method to return a +// freshly-built netmap with up-to-date Peers, at O(N) cost per call. +// Callers that genuinely need the up-to-date peer set should use this +// method (and document why) so the upcoming change reaches them. +// +// Returns nil if no network map has been received yet. +func (b *LocalBackend) NetMapWithPeers() *netmap.NetworkMap { + return b.currentNode().NetMap() +} + +// lookupPeerByIP returns the node public key for the peer that owns the +// given IP address. It is the fast path for [Engine.SetPeerByIPPacketFunc], +// handling exact-IP matches against node addresses; subnet routes and exit +// nodes are handled by a BART-based fallback in userspaceEngine that uses +// the wireguard-filtered peer list (see lastCfgFull). +// +// It is called by wireguard-go on every outbound packet (not cached), so +// it must be fast. +func (b *LocalBackend) lookupPeerByIP(ip netip.Addr) (key.NodePublic, bool) { + nb := b.currentNode() + nid, ok := nb.NodeByAddr(ip) + if !ok { + return key.NodePublic{}, false + } + peer, ok := nb.NodeByID(nid) + if !ok { + return key.NodePublic{}, false + } + return peer.Key(), true +} + func (b *LocalBackend) isEngineBlocked() bool { b.mu.Lock() defer b.mu.Unlock() @@ -5057,10 +5329,26 @@ func (b *LocalBackend) readvertiseAppConnectorRoutes() { // authReconfig pushes a new configuration into wgengine, if engine // updates are not currently blocked, based on the cached netmap and -// user prefs. +// user prefs. Callers may experience an early return with no work +// done if another goroutine is waiting for the mutex inside this method. +// If there is no other goroutine waiting, the calling goroutine will +// proceed to reconfiguration after acquiring the mutex. + +// Reconfiguration may run asynchronously and may not complete +// before the call returns. func (b *LocalBackend) authReconfig() { + // If there's already a pending auth reconfig from another + // goroutine, exit early. If not, this goroutine becomes the pending. + if b.existsPendingAuthReconfig.Swap(true) { + return + } + b.mu.Lock() defer b.mu.Unlock() + + // Allow another goroutine to become pending. + b.existsPendingAuthReconfig.Store(false) + b.authReconfigLocked() } @@ -5068,7 +5356,6 @@ func (b *LocalBackend) authReconfig() { // // b.mu must be held. func (b *LocalBackend) authReconfigLocked() { - if b.shutdownCalled { b.logf("[v1] authReconfig: skipping because in shutdown") return @@ -5135,6 +5422,16 @@ func (b *LocalBackend) authReconfigLocked() { oneCGNATRoute := shouldUseOneCGNATRoute(b.logf, b.sys.NetMon.Get(), b.sys.ControlKnobs(), version.OS()) rcfg := b.routerConfigLocked(cfg, prefs, nm, oneCGNATRoute) + // Add these extra Allowed IPs after router configuration, because the expected + // extension (features/conn25), does not want these routes installed on the OS. + // See also [Hooks.ExtraWireGuardAllowedIPs]. + if extraAllowedIPsFn, ok := b.extHost.hooks.ExtraWireGuardAllowedIPs.GetOk(); ok { + for i := range cfg.Peers { + extras := extraAllowedIPsFn(cfg.Peers[i].PublicKey) + cfg.Peers[i].AllowedIPs = extras.AppendTo(cfg.Peers[i].AllowedIPs) + } + } + err = b.e.Reconfig(cfg, rcfg, dcfg) if err == wgengine.ErrNoChanges { return @@ -5337,7 +5634,7 @@ func (b *LocalBackend) initPeerAPIListenerLocked() { cn := b.currentNode() nm := cn.NetMap() if nm == nil { - // We're called from authReconfig which checks that + // We're called from authReconfigLocked which checks that // netMap is non-nil, but if a concurrent Logout, // ResetForClientDisconnect, or Start happens when its // mutex was released, the netMap could be @@ -5532,13 +5829,14 @@ func (b *LocalBackend) routerConfigLocked(cfg *wgcfg.Config, prefs ipn.PrefsView } rs := &router.Config{ - LocalAddrs: unmapIPPrefixes(cfg.Addresses), - SubnetRoutes: unmapIPPrefixes(prefs.AdvertiseRoutes().AsSlice()), - SNATSubnetRoutes: !prefs.NoSNAT(), - StatefulFiltering: doStatefulFiltering, - NetfilterMode: prefs.NetfilterMode(), - Routes: peerRoutes(b.logf, cfg.Peers, singleRouteThreshold, prefs.RouteAll()), - NetfilterKind: netfilterKind, + LocalAddrs: unmapIPPrefixes(cfg.Addresses), + SubnetRoutes: unmapIPPrefixes(prefs.AdvertiseRoutes().AsSlice()), + SNATSubnetRoutes: !prefs.NoSNAT(), + StatefulFiltering: doStatefulFiltering, + NetfilterMode: prefs.NetfilterMode(), + Routes: peerRoutes(b.logf, cfg.Peers, singleRouteThreshold, prefs.RouteAll()), + NetfilterKind: netfilterKind, + RemoveCGNATDropRule: nm.HasCap(tailcfg.NodeAttrDisableLinuxCGNATDropRule), } if buildfeatures.HasSynology && distro.Get() == distro.Synology { @@ -5612,6 +5910,11 @@ func (b *LocalBackend) routerConfigLocked(cfg *wgcfg.Config, prefs ipn.PrefsView } } + // Get any extra Routes an extension may want installed. + if extensionRoutesFx, ok := b.extHost.hooks.ExtraRouterConfigRoutes.GetOk(); ok { + rs.Routes = extensionRoutesFx().AppendTo(rs.Routes) + } + return rs } @@ -5640,6 +5943,22 @@ func (b *LocalBackend) applyPrefsToHostinfoLocked(hi *tailcfg.Hostinfo, prefs ip if buildfeatures.HasAdvertiseRoutes { b.metrics.advertisedRoutes.Set(float64(tsaddr.WithoutExitRoute(prefs.AdvertiseRoutes()).Len())) + + // Set up IP forwarding check when routes change + if len(hi.RoutableIPs) > 0 && b.NetMon() != nil && !b.sys.IsNetstackRouter() { + routes := hi.RoutableIPs + netMon := b.NetMon() + b.health.SetIPForwardingCheck(func() bool { + warn, err := netutil.CheckIPForwarding(routes, netMon.InterfaceState()) + if err != nil { + metricIPForwardingCheckError.Add(1) + return false // don't want false positives + } + return warn != nil // true if broken + }) + } else { + b.health.SetIPForwardingCheck(nil) + } } var sshHostKeys []string @@ -5647,10 +5966,12 @@ func (b *LocalBackend) applyPrefsToHostinfoLocked(hi *tailcfg.Hostinfo, prefs ip // TODO(bradfitz): this is called with b.mu held. Not ideal. // If the filesystem gets wedged or something we could block for // a long time. But probably fine. - var err error - sshHostKeys, err = b.getSSHHostKeyPublicStrings() - if err != nil { - b.logf("warning: unable to get SSH host keys, SSH will appear as disabled for this node: %v", err) + if f, ok := feature.HookGetSSHHostKeyPublicStrings.GetOk(); ok { + var err error + sshHostKeys, err = f(b.TailscaleVarRoot(), b.logf) + if err != nil { + b.logf("warning: unable to get SSH host keys, SSH will appear as disabled for this node: %v", err) + } } } hi.SSH_HostKeys = sshHostKeys @@ -5763,9 +6084,9 @@ func (b *LocalBackend) enterStateLocked(newState ipn.State) { switch newState { case ipn.NeedsLogin: feature.SystemdStatus("Needs login: %s", authURL) - // always block updates on NeedsLogin even if seamless renewal is enabled, - // to prevent calls to authReconfig from reconfiguring the engine when our - // key has expired and we're waiting to authenticate to use the new key. + // always block updates on NeedsLogin, to prevent calls to authReconfigLocked + // from reconfiguring the engine when our key has expired and we're waiting + // to authenticate to use the new key. b.blockEngineUpdatesLocked(true) fallthrough case ipn.Stopped, ipn.NoState: @@ -6190,7 +6511,7 @@ func (b *LocalBackend) resolveExitNodeLocked() (changed bool) { b.goTracker.Go(b.doSetHostinfoFilterServices) } - b.sendToLocked(ipn.Notify{Prefs: ptr.To(prefs.View())}, allClients) + b.sendToLocked(ipn.Notify{Prefs: new(prefs.View())}, allClients) return true } @@ -6237,6 +6558,23 @@ func (b *LocalBackend) resolveExitNodeInPrefsLocked(prefs *ipn.Prefs) (changed b // received nm. If nm is nil, it resets all configuration as though // Tailscale is turned off. func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { + if buildfeatures.HasCacheNetMap { + // As a defensive measure, if something triggers a panic when we are + // installing a network map, make an effort to discard any cached netmaps. + // This helps avert the possibility that a restart after panic will stick in + // a cycle. Importantly, we do not attempt to swallow or handle the panic, + // since that indicates a real bug. + // + // See https://github.com/tailscale/tailscale/issues/12639 + defer func() { + if p := recover(); p != nil { + b.logf("WARNING: Panic while installing netmap; discardng caches") + b.discardDiskCacheLocked() + panic(p) // propagate + } + }() + } + oldSelf := b.currentNode().NetMap().SelfNodeOrZero() b.dialer.SetNetMap(nm) @@ -6246,11 +6584,6 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { var login string if nm != nil { login = cmp.Or(profileFromView(nm.UserProfiles[nm.User()]).LoginName, "") - if envknob.Bool("TS_USE_CACHED_NETMAP") { - if err := b.writeNetmapToDiskLocked(nm); err != nil { - b.logf("write netmap to cache: %v", err) - } - } } b.currentNode().SetNetMap(nm) if ms, ok := b.sys.MagicSock.GetOK(); ok { @@ -6298,6 +6631,9 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { // See the netns package for documentation on what these capability do. netns.SetBindToInterfaceByRoute(b.logf, nm.HasCap(tailcfg.CapabilityBindToInterfaceByRoute)) + if runtime.GOOS == "android" { + netns.SetDisableAndroidBindToActiveNetwork(b.logf, nm.HasCap(tailcfg.NodeAttrDisableAndroidBindToActiveNetwork)) + } netns.SetDisableBindConnToInterface(b.logf, nm.HasCap(tailcfg.CapabilityDebugDisableBindConnToInterface)) netns.SetDisableBindConnToInterfaceAppleExt(b.logf, nm.HasCap(tailcfg.CapabilityDebugDisableBindConnToInterfaceAppleExt)) @@ -6341,6 +6677,29 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { f(b, nm) } } + + // Reaching here, we have successfully applied a new network map, and must + // now (if configured) update the cache. We do this after application to + // reduce the chance we will cache a QoD netmap. + // + // As of 2026-03-25 we require the envknob AND the node attribute to use + // a netmap cache, with the envknob defaulted to true so we can use it as + // a safety override during rollout. + // + // We treat the envknob being false as identical to disabling the feature + // by policy, and clean up the cache on that basis. That ensures we will + // not wind up in a situation where we have a stale cached netmap that is + // not being updated (because of the envknob) and could be read back when + // the node starts up. + if nm != nil { + if b.currentNode().SelfHasCap(tailcfg.NodeAttrCacheNetworkMaps) && envknob.BoolDefaultTrue("TS_USE_CACHED_NETMAP") { + if err := b.writeNetmapToDiskLocked(nm); err != nil { + b.logf("write netmap to cache: %v", err) + } + } else { + b.discardDiskCacheLocked() + } + } } var hookSetNetMapLockedDrive feature.Hook[func(*LocalBackend, *netmap.NetworkMap)] @@ -6440,9 +6799,9 @@ func (b *LocalBackend) maybeSentHostinfoIfChangedLocked(prefs ipn.PrefsView) { } } -// operatorUserName returns the current pref's OperatorUser's name, or the +// OperatorUserName returns the current pref's OperatorUser's name, or the // empty string if none. -func (b *LocalBackend) operatorUserName() string { +func (b *LocalBackend) OperatorUserName() string { b.mu.Lock() defer b.mu.Unlock() prefs := b.pm.CurrentPrefs() @@ -6455,7 +6814,7 @@ func (b *LocalBackend) operatorUserName() string { // OperatorUserID returns the current pref's OperatorUser's ID (in // os/user.User.Uid string form), or the empty string if none. func (b *LocalBackend) OperatorUserID() string { - opUserName := b.operatorUserName() + opUserName := b.OperatorUserName() if opUserName == "" { return "" } @@ -6681,7 +7040,7 @@ func (b *LocalBackend) AppConnector() *appc.AppConnector { func (b *LocalBackend) allowExitNodeDNSProxyToServeName(name string) bool { b.mu.Lock() defer b.mu.Unlock() - nm := b.NetMap() + nm := b.NetMapNoPeers() if nm == nil { return false } @@ -6825,11 +7184,28 @@ func (b *LocalBackend) DebugRotateDiscoKey() error { b.mu.Lock() cc := b.cc + wantRunning := b.pm.CurrentPrefs().WantRunning() b.mu.Unlock() if cc != nil { cc.SetDiscoPublicKey(newDiscoKey) } + // Bounce WantRunning to fully reset wireguard-go state for all peers. + if wantRunning { + if _, err := b.EditPrefs(&ipn.MaskedPrefs{ + Prefs: ipn.Prefs{WantRunning: false}, + WantRunningSet: true, + }); err != nil { + return err + } + if _, err := b.EditPrefs(&ipn.MaskedPrefs{ + Prefs: ipn.Prefs{WantRunning: true}, + WantRunningSet: true, + }); err != nil { + return err + } + } + return nil } @@ -6837,6 +7213,25 @@ func (b *LocalBackend) DebugPeerRelayServers() set.Set[netip.Addr] { return b.MagicConn().PeerRelays() } +// DebugPeerDiscoKeys returns the disco public keys this node has learned for +// each of its peers from the most recent network map. Intended for tests +// (the production [ipnstate.PeerStatus] purposefully does not surface disco +// keys; surfacing them via the [ipnstate.Status] API would also pollute +// every PeerStatus consumer with a non-comparable struct field). +func (b *LocalBackend) DebugPeerDiscoKeys() map[key.NodePublic]key.DiscoPublic { + nm := b.currentNode().NetMap() + if nm == nil { + return nil + } + m := make(map[key.NodePublic]key.DiscoPublic, len(nm.Peers)) + for _, p := range nm.Peers { + if dk := p.DiscoKey(); !dk.IsZero() { + m[p.Key()] = dk + } + } + return m +} + // ControlKnobs returns the node's control knobs. func (b *LocalBackend) ControlKnobs() *controlknobs.Knobs { return b.sys.ControlKnobs() @@ -6903,6 +7298,18 @@ var warnSSHSELinuxWarnable = health.Register(&health.Warnable{ Text: health.StaticMessage("SELinux is enabled; Tailscale SSH may not work. See https://tailscale.com/s/ssh-selinux"), }) +// warnNoSNATWithExitNode is a warnable for when a node is advertising as an +// exit node but has SNAT disabled. In this configuration internet-bound traffic +// from peers using this exit node will not be masqueraded to the node's own +// source IP, so return packets cannot be routed back, causing the exit node to +// not work as expected. +var warnNoSNATWithExitNode = health.Register(&health.Warnable{ + Code: "nosnat-with-advertised-exit-node", + Title: "Exit node advertising may not work correctly", + Severity: health.SeverityMedium, + Text: health.StaticMessage("snat-subnet-routes is disabled while advertising as an exit node; internet traffic through this exit node may not work as expected"), +}) + func (b *LocalBackend) updateSELinuxHealthWarning() { if hostinfo.IsSELinuxEnforcing() { b.health.SetUnhealthy(warnSSHSELinuxWarnable, nil) @@ -6919,6 +7326,17 @@ func (b *LocalBackend) updateWarnSync(prefs ipn.PrefsView) { } } +func (b *LocalBackend) updateNoSNATExitNodeWarning(prefs ipn.PrefsView) { + if !buildfeatures.HasAdvertiseExitNode { + return + } + if prefs.NoSNAT() && prefs.AdvertisesExitNode() { + b.health.SetUnhealthy(warnNoSNATWithExitNode, nil) + } else { + b.health.SetHealthy(warnNoSNATWithExitNode) + } +} + func (b *LocalBackend) handleSSHConn(c net.Conn) (err error) { s, err := b.sshServerOrInit() if err != nil { @@ -7235,6 +7653,9 @@ func (b *LocalBackend) AdvertiseRoute(ipps ...netip.Prefix) error { var newRoutes []netip.Prefix for _, ipp := range ipps { + if ipp != ipp.Masked() { + return fmt.Errorf("route %s has non-address bits set; expected %s", ipp, ipp.Masked()) + } if !allowedAutoRoute(ipp) { continue } @@ -7359,14 +7780,6 @@ func (b *LocalBackend) ReadRouteInfo() (*appctype.RouteInfo, error) { return b.readRouteInfoLocked() } -// seamlessRenewalEnabled reports whether seamless key renewals are enabled. -// -// As of 2025-09-11, this is the default behaviour unless nodes receive -// [tailcfg.NodeAttrDisableSeamlessKeyRenewal] in their netmap. -func (b *LocalBackend) seamlessRenewalEnabled() bool { - return b.ControlKnobs().SeamlessKeyRenewal.Load() -} - var ( disallowedAddrs = []netip.Addr{ netip.MustParseAddr("::1"), @@ -7383,10 +7796,8 @@ var ( // allowedAutoRoute determines if the route being added via AdvertiseRoute (the app connector featuge) should be allowed. func allowedAutoRoute(ipp netip.Prefix) bool { // Note: blocking the addrs for globals, not solely the prefixes. - for _, addr := range disallowedAddrs { - if ipp.Addr() == addr { - return false - } + if slices.Contains(disallowedAddrs, ipp.Addr()) { + return false } for _, pfx := range disallowedRanges { if pfx.Overlaps(ipp) { @@ -7915,7 +8326,8 @@ func maybeUsernameOf(actor ipnauth.Actor) string { } var ( - metricCurrentWatchIPNBus = clientmetric.NewGauge("localbackend_current_watch_ipn_bus") + metricCurrentWatchIPNBus = clientmetric.NewGauge("localbackend_current_watch_ipn_bus") + metricIPForwardingCheckError = clientmetric.NewCounter("localbackend_ip_forwarding_check_error") ) func (b *LocalBackend) stateEncrypted() opt.Bool { diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index 259e4b6b28a83..70cbc89914df7 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -61,7 +61,6 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/opt" "tailscale.com/types/persist" - "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/dnsname" "tailscale.com/util/eventbus" @@ -711,6 +710,118 @@ func TestLoadCachedNetMap(t *testing.T) { } } +func TestUpdateNetMapCache(t *testing.T) { + t.Setenv("TS_USE_CACHED_NETMAP", "1") + + // Set up a cache directory so we can check what happens to it, in response + // to netmap updates. + varRoot := t.TempDir() + cacheDir := filepath.Join(varRoot, "profile-data", "id0", "netmap-cache") + + testMap := &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "example.ts.net", + User: tailcfg.UserID(1), + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.2.3.4/32"), + }, + }).View(), + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + tailcfg.UserID(1): (&tailcfg.UserProfile{ + ID: 1, + LoginName: "amelie@example.com", + DisplayName: "Amelie du Pangoline", + }).View(), + }, + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 601, + StableID: "n601FAKE", + ComputedName: "some-peer", + User: tailcfg.UserID(1), + Key: makeNodeKeyFromID(601), + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.2.3.5/32"), + }, + }).View(), + }, + } + + // Make a new backend to which we can send network maps to test that + // netmap caching decisions are made appropriately. + sys := tsd.NewSystem() + e, err := wgengine.NewFakeUserspaceEngine(logger.Discard, + sys.Set, + sys.HealthTracker.Get(), + sys.UserMetricsRegistry(), + sys.Bus.Get(), + ) + if err != nil { + t.Fatalf("Make userspace engine: %v", err) + } + t.Cleanup(e.Close) + sys.Set(e) + sys.Set(new(mem.Store)) + + logf := tstest.WhileTestRunningLogger(t) + clb, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0) + if err != nil { + t.Fatalf("Make local backend: %v", err) + } + t.Cleanup(clb.Shutdown) + clb.SetVarRoot(varRoot) + + pm := must.Get(newProfileManager(new(mem.Store), logf, health.NewTracker(sys.Bus.Get()))) + pm.currentProfile = (&ipn.LoginProfile{ID: "id0"}).View() + clb.pm = pm + if err := clb.Start(ipn.Options{}); err != nil { + t.Fatalf("Start local backend: %v", err) + } + + wantCacheEmpty := func() { + // The cache directory should be empty, as caching is not enabled. + if des, err := os.ReadDir(cacheDir); err != nil { + t.Errorf("List cache directory: %v", err) + } else if len(des) != 0 { + t.Errorf("Cache directory has %d items, want 0\n%+v", len(des), des) + } + } + + // Send the initial network map to the backend. Because the map does not + // include the cache attribute, no cache should be written. + clb.mu.Lock() + clb.setNetMapLocked(testMap) + clb.mu.Unlock() + + wantCacheEmpty() + + // Now enable the netmap caching attribute, and send another update. + // After doing so, the cache should have real data in it. + testMap.AllCaps = set.Of(tailcfg.NodeAttrCacheNetworkMaps) + + clb.mu.Lock() + clb.setNetMapLocked(testMap) + clb.mu.Unlock() + + if des, err := os.ReadDir(cacheDir); err != nil { + t.Errorf("List cache directory: %v", err) + } else if len(des) == 0 { + t.Error("Cache is unexpectedly empty") + } else { + t.Logf("Cache directory has %d entries (OK)", len(des)) + } + + // Now disable the node attribute again, send another update, and verify + // that the cache got cleaned up. + testMap.AllCaps = nil + + clb.mu.Lock() + clb.setNetMapLocked(testMap) + clb.mu.Unlock() + + wantCacheEmpty() +} + func TestConfigureExitNode(t *testing.T) { controlURL := "https://localhost:1/" exitNode1 := makeExitNode(1, withName("node-1"), withDERP(1), withAddresses(netip.MustParsePrefix("100.64.1.1/32"))) @@ -877,7 +988,7 @@ func TestConfigureExitNode(t *testing.T) { Prefs: ipn.Prefs{AutoExitNode: "any"}, AutoExitNodeSet: true, }, - useExitNodeEnabled: ptr.To(false), + useExitNodeEnabled: new(false), wantPrefs: ipn.Prefs{ ControlURL: controlURL, ExitNodeID: "", @@ -894,7 +1005,7 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: clientNetmap, report: report, - useExitNodeEnabled: ptr.To(true), + useExitNodeEnabled: new(true), wantPrefs: ipn.Prefs{ ControlURL: controlURL, ExitNodeID: exitNode1.StableID(), @@ -909,7 +1020,7 @@ func TestConfigureExitNode(t *testing.T) { ControlURL: controlURL, }, netMap: clientNetmap, - exitNodeIDPolicy: ptr.To(exitNode1.StableID()), + exitNodeIDPolicy: new(exitNode1.StableID()), wantPrefs: ipn.Prefs{ ControlURL: controlURL, ExitNodeID: exitNode1.StableID(), @@ -922,7 +1033,7 @@ func TestConfigureExitNode(t *testing.T) { ControlURL: controlURL, }, netMap: clientNetmap, - exitNodeIDPolicy: ptr.To(exitNode1.StableID()), + exitNodeIDPolicy: new(exitNode1.StableID()), changePrefs: &ipn.MaskedPrefs{ Prefs: ipn.Prefs{ ExitNodeID: exitNode2.StableID(), // this should be ignored @@ -942,7 +1053,7 @@ func TestConfigureExitNode(t *testing.T) { ControlURL: controlURL, }, netMap: clientNetmap, - exitNodeIDPolicy: ptr.To(exitNode1.StableID()), + exitNodeIDPolicy: new(exitNode1.StableID()), changePrefs: &ipn.MaskedPrefs{ Prefs: ipn.Prefs{ ExitNodeIP: exitNode2.Addresses().At(0).Addr(), // this should be ignored @@ -962,7 +1073,7 @@ func TestConfigureExitNode(t *testing.T) { ControlURL: controlURL, }, netMap: clientNetmap, - exitNodeIDPolicy: ptr.To(exitNode1.StableID()), + exitNodeIDPolicy: new(exitNode1.StableID()), changePrefs: &ipn.MaskedPrefs{ Prefs: ipn.Prefs{ AutoExitNode: "any", // this should be ignored @@ -982,7 +1093,7 @@ func TestConfigureExitNode(t *testing.T) { ControlURL: controlURL, }, netMap: clientNetmap, - exitNodeIPPolicy: ptr.To(exitNode2.Addresses().At(0).Addr()), + exitNodeIPPolicy: new(exitNode2.Addresses().At(0).Addr()), wantPrefs: ipn.Prefs{ ControlURL: controlURL, ExitNodeID: exitNode2.StableID(), @@ -996,7 +1107,7 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: clientNetmap, report: report, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:any")), wantPrefs: ipn.Prefs{ ControlURL: controlURL, ExitNodeID: exitNode1.StableID(), @@ -1011,7 +1122,7 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: clientNetmap, report: nil, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:any")), wantPrefs: ipn.Prefs{ ControlURL: controlURL, ExitNodeID: unresolvedExitNodeID, @@ -1026,7 +1137,7 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: nil, report: report, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:any")), wantPrefs: ipn.Prefs{ ControlURL: controlURL, ExitNodeID: unresolvedExitNodeID, @@ -1042,7 +1153,7 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: nil, report: report, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:any")), exitNodeAllowedIDs: nil, // not configured, so all exit node IDs are implicitly allowed wantPrefs: ipn.Prefs{ ControlURL: controlURL, @@ -1059,7 +1170,7 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: nil, report: report, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:any")), exitNodeAllowedIDs: []tailcfg.StableNodeID{ exitNode2.StableID(), // the current exit node ID is allowed }, @@ -1078,7 +1189,7 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: nil, report: report, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:any")), exitNodeAllowedIDs: []tailcfg.StableNodeID{ exitNode1.StableID(), // a different exit node ID; the current one is not allowed }, @@ -1097,7 +1208,7 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: clientNetmap, report: report, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:any")), exitNodeAllowedIDs: []tailcfg.StableNodeID{ exitNode2.StableID(), // a different exit node ID; the current one is not allowed }, @@ -1116,7 +1227,7 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: clientNetmap, report: report, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:any")), wantPrefs: ipn.Prefs{ ControlURL: controlURL, ExitNodeID: exitNode1.StableID(), // switch to the best exit node @@ -1131,7 +1242,7 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: clientNetmap, report: report, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:foo")), + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:foo")), wantPrefs: ipn.Prefs{ ControlURL: controlURL, ExitNodeID: exitNode1.StableID(), // unknown exit node expressions should work as "any" @@ -1164,8 +1275,8 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: clientNetmap, report: report, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), - useExitNodeEnabled: ptr.To(false), // should fail with an error + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:any")), + useExitNodeEnabled: new(false), // should fail with an error wantExitNodeToggleErr: errManagedByPolicy, wantPrefs: ipn.Prefs{ ControlURL: controlURL, @@ -1182,7 +1293,7 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: clientNetmap, report: report, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:any")), exitNodeAllowOverride: true, // allow changing the exit node changePrefs: &ipn.MaskedPrefs{ Prefs: ipn.Prefs{ @@ -1204,7 +1315,7 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: clientNetmap, report: report, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:any")), exitNodeAllowOverride: true, // allow changing, but not disabling, the exit node changePrefs: &ipn.MaskedPrefs{ Prefs: ipn.Prefs{ @@ -1228,9 +1339,9 @@ func TestConfigureExitNode(t *testing.T) { }, netMap: clientNetmap, report: report, - exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), - exitNodeAllowOverride: true, // allow changing, but not disabling, the exit node - useExitNodeEnabled: ptr.To(false), // should fail with an error + exitNodeIDPolicy: new(tailcfg.StableNodeID("auto:any")), + exitNodeAllowOverride: true, // allow changing, but not disabling, the exit node + useExitNodeEnabled: new(false), // should fail with an error wantExitNodeToggleErr: errManagedByPolicy, wantPrefs: ipn.Prefs{ ControlURL: controlURL, @@ -1992,15 +2103,15 @@ func TestUpdateNetmapDelta(t *testing.T) { }, { NodeID: 2, - Online: ptr.To(true), + Online: new(true), }, { NodeID: 3, - Online: ptr.To(false), + Online: new(false), }, { NodeID: 4, - LastSeen: ptr.To(someTime), + LastSeen: new(someTime), }, }, }, someTime) @@ -2021,17 +2132,17 @@ func TestUpdateNetmapDelta(t *testing.T) { { ID: 2, Key: makeNodeKeyFromID(2), - Online: ptr.To(true), + Online: new(true), }, { ID: 3, Key: makeNodeKeyFromID(3), - Online: ptr.To(false), + Online: new(false), }, { ID: 4, Key: makeNodeKeyFromID(4), - LastSeen: ptr.To(someTime), + LastSeen: new(someTime), }, } for _, want := range wants { @@ -2047,9 +2158,53 @@ func TestUpdateNetmapDelta(t *testing.T) { } } -// tests WhoIs and indirectly that setNetMapLocked updates b.nodeByAddr correctly. +type whoIsTestParams struct { + testName string + q string + want tailcfg.NodeID // 0 means want ok=false + wantName string + wantGroups []string +} + +func expectWhois(t *testing.T, tests []whoIsTestParams, b *LocalBackend) { + t.Helper() + + checkWhoIs := func(t *testing.T, tt whoIsTestParams, nv tailcfg.NodeView, up tailcfg.UserProfile, ok bool) { + t.Helper() + var got tailcfg.NodeID + if ok { + got = nv.ID() + } + if got != tt.want { + t.Errorf("got nodeID %v; want %v", got, tt.want) + } + if up.DisplayName != tt.wantName { + t.Errorf("got name %q; want %q", up.DisplayName, tt.wantName) + } + if !slices.Equal(up.Groups, tt.wantGroups) { + t.Errorf("got groups %q; want %q", up.Groups, tt.wantGroups) + } + } + + for _, tt := range tests { + t.Run("ByAddr/"+tt.testName, func(t *testing.T) { + nv, up, ok := b.WhoIs("", netip.MustParseAddrPort(tt.q)) + checkWhoIs(t, tt, nv, up, ok) + }) + t.Run("ByNodeKey/"+tt.testName, func(t *testing.T) { + nv, up, ok := b.WhoIsNodeKey(makeNodeKeyFromID(tt.want)) + checkWhoIs(t, tt, nv, up, ok) + }) + } +} + +// Test WhoIs and WhoIsNodeKey. +// This indirectly asserts that localBackend's setNetMapLocked updates nodeBackend's b.nodeByAddr and b.nodeByKey correctly. func TestWhoIs(t *testing.T) { + b := newTestLocalBackend(t) + + // Simple two-node netmap. b.setNetMapLocked(&netmap.NetworkMap{ SelfNode: (&tailcfg.Node{ ID: 1, @@ -2070,36 +2225,106 @@ func TestWhoIs(t *testing.T) { DisplayName: "Myself", }).View(), 20: (&tailcfg.UserProfile{ - DisplayName: "Peer", + DisplayName: "Peer2", + Groups: []string{"group:foo"}, }).View(), }, }) - tests := []struct { - q string - want tailcfg.NodeID // 0 means want ok=false - wantName string - }{ - {"100.101.102.103:0", 1, "Myself"}, - {"100.101.102.103:123", 1, "Myself"}, - {"100.200.200.200:0", 2, "Peer"}, - {"100.200.200.200:123", 2, "Peer"}, - {"100.4.0.4:404", 0, ""}, + testsRound1 := []whoIsTestParams{ + {"round1MyselfNoPort", "100.101.102.103:0", 1, "Myself", nil}, + {"round1MyselfWithPort", "100.101.102.103:123", 1, "Myself", nil}, + {"round1Peer2NoPort", "100.200.200.200:0", 2, "Peer2", []string{"group:foo"}}, + {"round1Peer2WithPort", "100.200.200.200:123", 2, "Peer2", []string{"group:foo"}}, + {"round1UnknownPeer", "100.4.0.4:404", 0, "", nil}, } - for _, tt := range tests { - t.Run(tt.q, func(t *testing.T) { - nv, up, ok := b.WhoIs("", netip.MustParseAddrPort(tt.q)) - var got tailcfg.NodeID - if ok { - got = nv.ID() - } - if got != tt.want { - t.Errorf("got nodeID %v; want %v", got, tt.want) - } - if up.DisplayName != tt.wantName { - t.Errorf("got name %q; want %q", up.DisplayName, tt.wantName) - } - }) + expectWhois(t, testsRound1, b) + + // Now push a new netmap where a new peer is added + // This verifies we add nodes to indexes correctly + b.setNetMapLocked(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + ID: 1, + User: 10, + Key: makeNodeKeyFromID(1), + Addresses: []netip.Prefix{netip.MustParsePrefix("100.101.102.103/32")}, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 2, + User: 20, + Key: makeNodeKeyFromID(2), + Addresses: []netip.Prefix{netip.MustParsePrefix("100.200.200.200/32")}, + }).View(), + (&tailcfg.Node{ + ID: 3, + User: 30, + Key: makeNodeKeyFromID(3), + Addresses: []netip.Prefix{netip.MustParsePrefix("100.233.233.233/32")}, + }).View(), + }, + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + 10: (&tailcfg.UserProfile{ + DisplayName: "Myself", + }).View(), + 20: (&tailcfg.UserProfile{ + DisplayName: "Peer2", + Groups: []string{"group:foo"}, + }).View(), + 30: (&tailcfg.UserProfile{ + DisplayName: "Peer3", + }).View(), + }, + }) + + testsRound2 := []whoIsTestParams{ + {"round2MyselfNoPort", "100.101.102.103:0", 1, "Myself", nil}, + {"round2MyselfWithPort", "100.101.102.103:123", 1, "Myself", nil}, + {"round2Peer2NoPort", "100.200.200.200:0", 2, "Peer2", []string{"group:foo"}}, + {"round2Peer2WithPort", "100.200.200.200:123", 2, "Peer2", []string{"group:foo"}}, + {"round2Peer3NoPort", "100.233.233.233:0", 3, "Peer3", nil}, + {"round2Peer3WithPort", "100.233.233.233:123", 3, "Peer3", nil}, + {"round2UnknownPeer", "100.4.0.4:404", 0, "", nil}, + } + expectWhois(t, testsRound2, b) + + // Finally push a new netmap where a peer is removed + // This verifies we remove nodes from indexes correctly + b.setNetMapLocked(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + ID: 1, + User: 10, + Key: makeNodeKeyFromID(1), + Addresses: []netip.Prefix{netip.MustParsePrefix("100.101.102.103/32")}, + }).View(), + Peers: []tailcfg.NodeView{ + // Node ID 2 removed + (&tailcfg.Node{ + ID: 3, + User: 30, + Key: makeNodeKeyFromID(3), + Addresses: []netip.Prefix{netip.MustParsePrefix("100.233.233.233/32")}, + }).View(), + }, + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + 10: (&tailcfg.UserProfile{ + DisplayName: "Myself", + }).View(), + 30: (&tailcfg.UserProfile{ + DisplayName: "Peer3", + }).View(), + }, + }) + + testsRound3 := []whoIsTestParams{ + {"round3MyselfNoPort", "100.101.102.103:0", 1, "Myself", nil}, + {"round3MyselfWithPort", "100.101.102.103:123", 1, "Myself", nil}, + {"round3Peer2NoPortUnknown", "100.200.200.200:0", 0, "", nil}, + {"round3Peer2WithPortUnknown", "100.200.200.200:123", 0, "", nil}, + {"round3Peer3NoPort", "100.233.233.233:0", 3, "Peer3", nil}, + {"round3Peer3WithPort", "100.233.233.233:123", 3, "Peer3", nil}, + {"round3UnknownPeer", "100.4.0.4:404", 0, "", nil}, } + expectWhois(t, testsRound3, b) } func TestWireguardExitNodeDNSResolvers(t *testing.T) { @@ -2810,20 +3035,20 @@ func TestSetExitNodeIDPolicy(t *testing.T) { lastSuggestedExitNode tailcfg.StableNodeID }{ { - name: "ExitNodeID key is set", + name: "exitNodeID-set", exitNodeIDKey: true, exitNodeID: "123", exitNodeIDWant: "123", prefsChanged: true, }, { - name: "ExitNodeID key not set", + name: "exitNodeID-not-set", exitNodeIDKey: true, exitNodeIDWant: "", prefsChanged: false, }, { - name: "ExitNodeID key set, ExitNodeIP preference set", + name: "exitNodeID-set-exitNodeIP-pref-set", exitNodeIDKey: true, exitNodeID: "123", prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, @@ -2831,7 +3056,7 @@ func TestSetExitNodeIDPolicy(t *testing.T) { prefsChanged: true, }, { - name: "ExitNodeID key not set, ExitNodeIP key set", + name: "exitNodeID-not-set-exitNodeIP-set", exitNodeIPKey: true, exitNodeIP: "127.0.0.1", prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, @@ -2839,7 +3064,7 @@ func TestSetExitNodeIDPolicy(t *testing.T) { prefsChanged: false, }, { - name: "ExitNodeIP key set, existing ExitNodeIP pref", + name: "exitNodeIP-set-existing-pref", exitNodeIPKey: true, exitNodeIP: "127.0.0.1", prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, @@ -2847,7 +3072,7 @@ func TestSetExitNodeIDPolicy(t *testing.T) { prefsChanged: false, }, { - name: "existing preferences match policy", + name: "existing-prefs-match-policy", exitNodeIDKey: true, exitNodeID: "123", prefs: &ipn.Prefs{ExitNodeID: tailcfg.StableNodeID("123")}, @@ -2855,7 +3080,8 @@ func TestSetExitNodeIDPolicy(t *testing.T) { prefsChanged: false, }, { - name: "ExitNodeIP set if net map does not have corresponding node", + // ExitNodeIP is set when net map does not have a corresponding node. + name: "exitNodeIP-set-no-matching-node", exitNodeIPKey: true, prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, exitNodeIP: "127.0.0.1", @@ -2891,7 +3117,8 @@ func TestSetExitNodeIDPolicy(t *testing.T) { }, }, { - name: "ExitNodeIP cleared if net map has corresponding node - policy matches prefs", + // ExitNodeIP cleared when net map has corresponding node and policy matches prefs. + name: "exitNodeIP-cleared-matching-node-policy-matches", prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, exitNodeIPKey: true, exitNodeIP: "127.0.0.1", @@ -2931,7 +3158,8 @@ func TestSetExitNodeIDPolicy(t *testing.T) { }, }, { - name: "ExitNodeIP cleared if net map has corresponding node - no policy set", + // ExitNodeIP cleared when net map has corresponding node and no policy is set. + name: "exitNodeIP-cleared-matching-node-no-policy", prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, exitNodeIPWant: "", exitNodeIDWant: "123", @@ -2969,7 +3197,8 @@ func TestSetExitNodeIDPolicy(t *testing.T) { }, }, { - name: "ExitNodeIP cleared if net map has corresponding node - different exit node IP in policy", + // ExitNodeIP cleared when net map has corresponding node but policy has different exit node IP. + name: "exitNodeIP-cleared-matching-node-different-policy-IP", exitNodeIPKey: true, prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, exitNodeIP: "100.64.5.6", @@ -3009,7 +3238,7 @@ func TestSetExitNodeIDPolicy(t *testing.T) { }, }, { - name: "ExitNodeID key is set to auto:any and last suggested exit node is populated", + name: "exitNodeID-auto-any-last-suggested-populated", exitNodeIDKey: true, exitNodeID: "auto:any", lastSuggestedExitNode: "123", @@ -3018,7 +3247,7 @@ func TestSetExitNodeIDPolicy(t *testing.T) { prefsChanged: true, }, { - name: "ExitNodeID key is set to auto:any and last suggested exit node is not populated", + name: "exitNodeID-auto-any-last-suggested-not-populated", exitNodeIDKey: true, exitNodeID: "auto:any", exitNodeIDWant: "auto:any", @@ -3026,7 +3255,7 @@ func TestSetExitNodeIDPolicy(t *testing.T) { prefsChanged: true, }, { - name: "ExitNodeID key is set to auto:foo and last suggested exit node is populated", + name: "exitNodeID-auto-foo-last-suggested-populated", exitNodeIDKey: true, exitNodeID: "auto:foo", lastSuggestedExitNode: "123", @@ -3035,7 +3264,7 @@ func TestSetExitNodeIDPolicy(t *testing.T) { prefsChanged: true, }, { - name: "ExitNodeID key is set to auto:foo and last suggested exit node is not populated", + name: "exitNodeID-auto-foo-last-suggested-not-populated", exitNodeIDKey: true, exitNodeID: "auto:foo", exitNodeIDWant: "auto:any", // should be "auto:any" for compatibility with existing clients @@ -3149,11 +3378,11 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) { muts: []*tailcfg.PeerChange{ { NodeID: 1, - Online: ptr.To(true), + Online: new(true), }, { NodeID: 2, - Online: ptr.To(false), // the selected exit node goes offline + Online: new(false), // the selected exit node goes offline }, }, exitNodeIDWant: peer1.StableID(), @@ -3173,11 +3402,11 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) { muts: []*tailcfg.PeerChange{ { NodeID: 1, - Online: ptr.To(false), // a different exit node goes offline + Online: new(false), // a different exit node goes offline }, { NodeID: 2, - Online: ptr.To(true), + Online: new(true), }, }, exitNodeIDWant: peer2.StableID(), @@ -3420,10 +3649,10 @@ func TestApplySysPolicy(t *testing.T) { stringPolicies map[pkey.Key]string }{ { - name: "empty prefs without policies", + name: "empty-prefs-no-policies", }, { - name: "prefs set without policies", + name: "prefs-set-no-policies", prefs: ipn.Prefs{ ControlURL: "1", ShieldsUp: true, @@ -3442,7 +3671,7 @@ func TestApplySysPolicy(t *testing.T) { }, }, { - name: "empty prefs with policies", + name: "empty-prefs-with-policies", wantPrefs: ipn.Prefs{ ControlURL: "1", ShieldsUp: true, @@ -3462,7 +3691,7 @@ func TestApplySysPolicy(t *testing.T) { }, }, { - name: "prefs set with matching policies", + name: "prefs-set-matching-policies", prefs: ipn.Prefs{ ControlURL: "1", ShieldsUp: true, @@ -3483,7 +3712,7 @@ func TestApplySysPolicy(t *testing.T) { }, }, { - name: "prefs set with conflicting policies", + name: "prefs-set-conflicting-policies", prefs: ipn.Prefs{ ControlURL: "1", ShieldsUp: true, @@ -3511,7 +3740,7 @@ func TestApplySysPolicy(t *testing.T) { }, }, { - name: "prefs set with neutral policies", + name: "prefs-set-neutral-policies", prefs: ipn.Prefs{ ControlURL: "1", ShieldsUp: true, @@ -3547,7 +3776,7 @@ func TestApplySysPolicy(t *testing.T) { }, }, { - name: "enable AutoUpdate apply does not unset check", + name: "enable-apply-keeps-check", prefs: ipn.Prefs{ AutoUpdate: ipn.AutoUpdatePrefs{ Check: true, @@ -3566,7 +3795,7 @@ func TestApplySysPolicy(t *testing.T) { }, }, { - name: "disable AutoUpdate apply does not unset check", + name: "disable-apply-keeps-check", prefs: ipn.Prefs{ AutoUpdate: ipn.AutoUpdatePrefs{ Check: true, @@ -3585,7 +3814,7 @@ func TestApplySysPolicy(t *testing.T) { }, }, { - name: "enable AutoUpdate check does not unset apply", + name: "enable-check-keeps-apply", prefs: ipn.Prefs{ AutoUpdate: ipn.AutoUpdatePrefs{ Check: false, @@ -3604,7 +3833,7 @@ func TestApplySysPolicy(t *testing.T) { }, }, { - name: "disable AutoUpdate check does not unset apply", + name: "disable-check-keeps-apply", prefs: ipn.Prefs{ AutoUpdate: ipn.AutoUpdatePrefs{ Check: true, @@ -3654,7 +3883,7 @@ func TestApplySysPolicy(t *testing.T) { } }) - t.Run("status update", func(t *testing.T) { + t.Run("status-update", func(t *testing.T) { // Profile manager fills in blank ControlURL but it's not set // in most test cases to avoid cluttering them, so adjust for // that. @@ -3694,75 +3923,75 @@ func TestPreferencePolicyInfo(t *testing.T) { policyError error }{ { - name: "force enable modify", + name: "force-enable-modify", initialValue: false, wantValue: true, wantChange: true, policyValue: "always", }, { - name: "force enable unchanged", + name: "force-enable-unchanged", initialValue: true, wantValue: true, policyValue: "always", }, { - name: "force disable modify", + name: "force-disable-modify", initialValue: true, wantValue: false, wantChange: true, policyValue: "never", }, { - name: "force disable unchanged", + name: "force-disable-unchanged", initialValue: false, wantValue: false, policyValue: "never", }, { - name: "unforced enabled", + name: "unforced-enabled", initialValue: true, wantValue: true, policyValue: "user-decides", }, { - name: "unforced disabled", + name: "unforced-disabled", initialValue: false, wantValue: false, policyValue: "user-decides", }, { - name: "blank enabled", + name: "blank-enabled", initialValue: true, wantValue: true, policyValue: "", }, { - name: "blank disabled", + name: "blank-disabled", initialValue: false, wantValue: false, policyValue: "", }, { - name: "unset enabled", + name: "unset-enabled", initialValue: true, wantValue: true, policyError: syspolicy.ErrNoSuchKey, }, { - name: "unset disabled", + name: "unset-disabled", initialValue: false, wantValue: false, policyError: syspolicy.ErrNoSuchKey, }, { - name: "error enabled", + name: "error-enabled", initialValue: true, wantValue: true, policyError: errors.New("test error"), }, { - name: "error disabled", + name: "error-disabled", initialValue: false, wantValue: false, policyError: errors.New("test error"), @@ -3888,53 +4117,62 @@ func TestOnTailnetDefaultAutoUpdate(t *testing.T) { func TestTCPHandlerForDst(t *testing.T) { b := newTestBackend(t) tests := []struct { + name string desc string dst string intercept bool }{ { + name: "100_100_100_100-port80", desc: "intercept port 80 (Web UI) on quad100 IPv4", dst: "100.100.100.100:80", intercept: true, }, { + name: "fd7a-115c-a1e0--53-port80", desc: "intercept port 80 (Web UI) on quad100 IPv6", dst: "[fd7a:115c:a1e0::53]:80", intercept: true, }, { + name: "100_100_103_100-port80", desc: "don't intercept port 80 on local ip", dst: "100.100.103.100:80", intercept: false, }, { + name: "fd7a-115c-a1e0--53-port8080", desc: "intercept port 8080 (Taildrive) on quad100 IPv4", dst: "[fd7a:115c:a1e0::53]:8080", intercept: true, }, { + name: "100_100_103_100-port8080", desc: "don't intercept port 8080 on local ip", dst: "100.100.103.100:8080", intercept: false, }, { + name: "100_100_100_100-port9080", desc: "don't intercept port 9080 on quad100 IPv4", dst: "100.100.100.100:9080", intercept: false, }, { + name: "fd7a-115c-a1e0--53-port9080", desc: "don't intercept port 9080 on quad100 IPv6", dst: "[fd7a:115c:a1e0::53]:9080", intercept: false, }, { + name: "100_100_103_100-port9080", desc: "don't intercept port 9080 on local ip", dst: "100.100.103.100:9080", intercept: false, }, } for _, tt := range tests { - t.Run(tt.dst, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { t.Log(tt.desc) src := netip.MustParseAddrPort("100.100.102.100:51234") h, _ := b.TCPHandlerForDst(src, netip.MustParseAddrPort(tt.dst)) @@ -4033,122 +4271,146 @@ func TestTCPHandlerForDstWithVIPService(t *testing.T) { } tests := []struct { + name string desc string dst string intercept bool }{ { + name: "100_100_100_100-port80", desc: "intercept port 80 (Web UI) on quad100 IPv4", dst: "100.100.100.100:80", intercept: true, }, { + name: "fd7a-115c-a1e0--53-port80", desc: "intercept port 80 (Web UI) on quad100 IPv6", dst: "[fd7a:115c:a1e0::53]:80", intercept: true, }, { + name: "100_100_103_100-port80", desc: "don't intercept port 80 on local ip", dst: "100.100.103.100:80", intercept: false, }, { + name: "100_100_100_100-port8080", desc: "intercept port 8080 (Taildrive) on quad100 IPv4", dst: "100.100.100.100:8080", intercept: true, }, { + name: "fd7a-115c-a1e0--53-port8080", desc: "intercept port 8080 (Taildrive) on quad100 IPv6", dst: "[fd7a:115c:a1e0::53]:8080", intercept: true, }, { + name: "100_100_103_100-port8080", desc: "don't intercept port 8080 on local ip", dst: "100.100.103.100:8080", intercept: false, }, { + name: "100_100_100_100-port9080", desc: "don't intercept port 9080 on quad100 IPv4", dst: "100.100.100.100:9080", intercept: false, }, { + name: "fd7a-115c-a1e0--53-port9080", desc: "don't intercept port 9080 on quad100 IPv6", dst: "[fd7a:115c:a1e0::53]:9080", intercept: false, }, { + name: "100_100_103_100-port9080", desc: "don't intercept port 9080 on local ip", dst: "100.100.103.100:9080", intercept: false, }, // VIP service destinations { + name: "100_101_101_101-port882", desc: "intercept port 882 (HTTP) on service foo IPv4", dst: "100.101.101.101:882", intercept: true, }, { + name: "fd7a-115c-a1e0-ab12-4843-cd96-6565-6565-port882", desc: "intercept port 882 (HTTP) on service foo IPv6", dst: "[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:882", intercept: true, }, { + name: "100_101_101_101-port883", desc: "intercept port 883 (HTTPS) on service foo IPv4", dst: "100.101.101.101:883", intercept: true, }, { + name: "fd7a-115c-a1e0-ab12-4843-cd96-6565-6565-port883", desc: "intercept port 883 (HTTPS) on service foo IPv6", dst: "[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:883", intercept: true, }, { + name: "100_99_99_99-port990", desc: "intercept port 990 (TCPForward) on service bar IPv4", dst: "100.99.99.99:990", intercept: true, }, { + name: "fd7a-115c-a1e0-ab12-4843-cd96-626b-628b-port990", desc: "intercept port 990 (TCPForward) on service bar IPv6", dst: "[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:990", intercept: true, }, { + name: "100_99_99_99-port990-terminateTLS", desc: "intercept port 991 (TCPForward with TerminateTLS) on service bar IPv4", dst: "100.99.99.99:990", intercept: true, }, { + name: "fd7a-115c-a1e0-ab12-4843-cd96-626b-628b-port990-terminateTLS", desc: "intercept port 991 (TCPForward with TerminateTLS) on service bar IPv6", dst: "[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:990", intercept: true, }, { + name: "100_101_101_101-port4444", desc: "don't intercept port 4444 on service foo IPv4", dst: "100.101.101.101:4444", intercept: false, }, { + name: "fd7a-115c-a1e0-ab12-4843-cd96-6565-6565-port4444", desc: "don't intercept port 4444 on service foo IPv6", dst: "[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:4444", intercept: false, }, { + name: "100_22_22_22-port883", desc: "don't intercept port 600 on unknown service IPv4", dst: "100.22.22.22:883", intercept: false, }, { + name: "fd7a-115c-a1e0-ab12-4843-cd96-626b-628b-port883", desc: "don't intercept port 600 on unknown service IPv6", dst: "[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:883", intercept: false, }, { + name: "100_133_133_133-port600", desc: "don't intercept port 600 (HTTPS) on service baz IPv4", dst: "100.133.133.133:600", intercept: false, }, { + name: "fd7a-115c-a1e0-ab12-4843-cd96-8585-8585-port600", desc: "don't intercept port 600 (HTTPS) on service baz IPv6", dst: "[fd7a:115c:a1e0:ab12:4843:cd96:8585:8585]:600", intercept: false, @@ -4156,7 +4418,7 @@ func TestTCPHandlerForDstWithVIPService(t *testing.T) { } for _, tt := range tests { - t.Run(tt.dst, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { t.Log(tt.desc) src := netip.MustParseAddrPort("100.100.102.100:51234") h, _ := b.TCPHandlerForDst(src, netip.MustParseAddrPort(tt.dst)) @@ -4326,7 +4588,7 @@ func TestDriveManageShares(t *testing.T) { b.driveSetSharesLocked(tt.existing) } if !tt.disabled { - nm := ptr.To(*b.currentNode().NetMap()) + nm := new(*b.currentNode().NetMap()) self := nm.SelfNode.AsStruct() self.CapMap = tailcfg.NodeCapMap{tailcfg.NodeAttrsTaildriveShare: nil} nm.SelfNode = self.View() @@ -4439,14 +4701,14 @@ func TestRoundTraffic(t *testing.T) { bytes int64 want float64 }{ - {name: "under 5 bytes", bytes: 4, want: 4}, - {name: "under 1000 bytes", bytes: 987, want: 990}, - {name: "under 10_000 bytes", bytes: 8875, want: 8900}, - {name: "under 100_000 bytes", bytes: 77777, want: 78000}, - {name: "under 1_000_000 bytes", bytes: 666523, want: 670000}, - {name: "under 10_000_000 bytes", bytes: 22556677, want: 23000000}, - {name: "under 1_000_000_000 bytes", bytes: 1234234234, want: 1200000000}, - {name: "under 1_000_000_000 bytes", bytes: 123423423499, want: 123400000000}, + {name: "under-5B", bytes: 4, want: 4}, + {name: "under-1000B", bytes: 987, want: 990}, + {name: "under-10000B", bytes: 8875, want: 8900}, + {name: "under-100000B", bytes: 77777, want: 78000}, + {name: "under-1000000B", bytes: 666523, want: 670000}, + {name: "under-10000000B", bytes: 22556677, want: 23000000}, + {name: "under-1000000000B", bytes: 1234234234, want: 1200000000}, + {name: "over-1000000000B", bytes: 123423423499, want: 123400000000}, } for _, tt := range tests { @@ -4476,7 +4738,7 @@ func makePeer(id tailcfg.NodeID, opts ...peerOptFunc) tailcfg.NodeView { DiscoKey: makeDiscoKeyFromID(id), StableID: tailcfg.StableNodeID(fmt.Sprintf("stable%d", id)), Name: fmt.Sprintf("peer%d", id), - Online: ptr.To(true), + Online: new(true), MachineAuthorized: true, HomeDERP: int(id), } @@ -4808,7 +5070,7 @@ func TestSuggestExitNode(t *testing.T) { wantError error }{ { - name: "2 exit nodes in same region", + name: "2-exits-same-region", lastReport: preferred1Report, netMap: &netmap.NetworkMap{ SelfNode: selfNode.View(), @@ -4826,7 +5088,7 @@ func TestSuggestExitNode(t *testing.T) { wantID: "stable1", }, { - name: "2 exit nodes different regions unknown latency", + name: "2-exits-different-regions-unknown-latency", lastReport: noLatency1Report, netMap: defaultNetmap, wantRegions: []int{1, 3}, // the only regions with peers @@ -4835,7 +5097,7 @@ func TestSuggestExitNode(t *testing.T) { wantID: "stable2", }, { - name: "2 derp based exit nodes, different regions, equal latency", + name: "2-derp-exits-different-regions-equal-latency", lastReport: &netcheck.Report{ RegionLatency: map[int]time.Duration{ 1: 10, @@ -4858,7 +5120,7 @@ func TestSuggestExitNode(t *testing.T) { wantID: "stable1", }, { - name: "mullvad nodes, no derp based exit nodes", + name: "mullvad-no-derp-exits", lastReport: noLatency1Report, netMap: locationNetmap, wantID: "stable5", @@ -4866,7 +5128,7 @@ func TestSuggestExitNode(t *testing.T) { wantName: "Dallas", }, { - name: "nearby mullvad nodes with different priorities", + name: "nearby-mullvad-different-priorities", lastReport: noLatency1Report, netMap: &netmap.NetworkMap{ SelfNode: selfNode.View(), @@ -4882,7 +5144,7 @@ func TestSuggestExitNode(t *testing.T) { wantName: "Fort Worth", }, { - name: "nearby mullvad nodes with same priorities", + name: "nearby-mullvad-same-priorities", lastReport: noLatency1Report, netMap: &netmap.NetworkMap{ SelfNode: selfNode.View(), @@ -4899,7 +5161,7 @@ func TestSuggestExitNode(t *testing.T) { wantName: "Dallas", }, { - name: "mullvad nodes, remaining node is not in preferred derp", + name: "mullvad-remaining-not-in-preferred-derp", lastReport: noLatency1Report, netMap: &netmap.NetworkMap{ SelfNode: selfNode.View(), @@ -4915,7 +5177,7 @@ func TestSuggestExitNode(t *testing.T) { wantName: "peer4", }, { - name: "no peers", + name: "no-peers", lastReport: noLatency1Report, netMap: &netmap.NetworkMap{ SelfNode: selfNode.View(), @@ -4923,13 +5185,13 @@ func TestSuggestExitNode(t *testing.T) { }, }, { - name: "nil report", + name: "nil-report", lastReport: nil, netMap: largeNetmap, wantError: ErrNoPreferredDERP, }, { - name: "no preferred derp region", + name: "no-preferred-derp-region", lastReport: preferredNoneReport, netMap: &netmap.NetworkMap{ SelfNode: selfNode.View(), @@ -4938,13 +5200,13 @@ func TestSuggestExitNode(t *testing.T) { wantError: ErrNoPreferredDERP, }, { - name: "nil netmap", + name: "nil-netmap", lastReport: noLatency1Report, netMap: nil, wantError: ErrNoPreferredDERP, }, { - name: "nil derpmap", + name: "nil-derpmap", lastReport: noLatency1Report, netMap: &netmap.NetworkMap{ SelfNode: selfNode.View(), @@ -4956,7 +5218,7 @@ func TestSuggestExitNode(t *testing.T) { wantError: ErrNoPreferredDERP, }, { - name: "missing suggestion capability", + name: "missing-suggestion-capability", lastReport: noLatency1Report, netMap: &netmap.NetworkMap{ SelfNode: selfNode.View(), @@ -4968,7 +5230,7 @@ func TestSuggestExitNode(t *testing.T) { }, }, { - name: "prefer last node", + name: "prefer-last-node", lastReport: preferred1Report, netMap: &netmap.NetworkMap{ SelfNode: selfNode.View(), @@ -4987,7 +5249,7 @@ func TestSuggestExitNode(t *testing.T) { wantID: "stable2", }, { - name: "found better derp node", + name: "found-better-derp-node", lastSuggestion: "stable3", lastReport: preferred1Report, netMap: defaultNetmap, @@ -4995,7 +5257,7 @@ func TestSuggestExitNode(t *testing.T) { wantName: "peer2", }, { - name: "prefer last mullvad node", + name: "prefer-last-mullvad-node", lastSuggestion: "stable2", lastReport: preferred1Report, netMap: &netmap.NetworkMap{ @@ -5013,7 +5275,7 @@ func TestSuggestExitNode(t *testing.T) { wantLocation: dallas.View(), }, { - name: "prefer better mullvad node", + name: "prefer-better-mullvad-node", lastSuggestion: "stable2", lastReport: preferred1Report, netMap: &netmap.NetworkMap{ @@ -5031,7 +5293,7 @@ func TestSuggestExitNode(t *testing.T) { wantLocation: fortWorth.View(), }, { - name: "large netmap", + name: "large-netmap", lastReport: preferred1Report, netMap: largeNetmap, wantNodes: []tailcfg.StableNodeID{"stable1", "stable2"}, @@ -5039,13 +5301,13 @@ func TestSuggestExitNode(t *testing.T) { wantName: "peer2", }, { - name: "no allowed suggestions", + name: "no-allowed-suggestions", lastReport: preferred1Report, netMap: largeNetmap, allowPolicy: []tailcfg.StableNodeID{}, }, { - name: "only derp suggestions", + name: "only-derp-suggestions", lastReport: preferred1Report, netMap: largeNetmap, allowPolicy: []tailcfg.StableNodeID{"stable1", "stable2", "stable3"}, @@ -5054,7 +5316,7 @@ func TestSuggestExitNode(t *testing.T) { wantName: "peer2", }, { - name: "only mullvad suggestions", + name: "only-mullvad-suggestions", lastReport: preferred1Report, netMap: largeNetmap, allowPolicy: []tailcfg.StableNodeID{"stable5", "stable6", "stable7"}, @@ -5063,7 +5325,7 @@ func TestSuggestExitNode(t *testing.T) { wantLocation: fortWorth.View(), }, { - name: "only worst derp", + name: "only-worst-derp", lastReport: preferred1Report, netMap: largeNetmap, allowPolicy: []tailcfg.StableNodeID{"stable3"}, @@ -5071,7 +5333,7 @@ func TestSuggestExitNode(t *testing.T) { wantName: "peer3", }, { - name: "only worst mullvad", + name: "only-worst-mullvad", lastReport: preferred1Report, netMap: largeNetmap, allowPolicy: []tailcfg.StableNodeID{"stable6"}, @@ -5081,7 +5343,7 @@ func TestSuggestExitNode(t *testing.T) { }, { // Regression test for https://github.com/tailscale/tailscale/issues/17661 - name: "exit nodes with no home DERP, randomly selected", + name: "exits-no-home-DERP-random-selection", lastReport: &netcheck.Report{ RegionLatency: map[int]time.Duration{ 1: 10, @@ -5163,7 +5425,7 @@ func TestSuggestExitNodePickWeighted(t *testing.T) { wantIDs []tailcfg.StableNodeID }{ { - name: "different priorities", + name: "different-priorities", candidates: []tailcfg.NodeView{ makePeer(2, withExitRoutes(), withLocation(location20.View())), makePeer(3, withExitRoutes(), withLocation(location10.View())), @@ -5171,7 +5433,7 @@ func TestSuggestExitNodePickWeighted(t *testing.T) { wantIDs: []tailcfg.StableNodeID{"stable2"}, }, { - name: "same priorities", + name: "same-priorities", candidates: []tailcfg.NodeView{ makePeer(2, withExitRoutes(), withLocation(location10.View())), makePeer(3, withExitRoutes(), withLocation(location10.View())), @@ -5179,11 +5441,11 @@ func TestSuggestExitNodePickWeighted(t *testing.T) { wantIDs: []tailcfg.StableNodeID{"stable2", "stable3"}, }, { - name: "<1 candidates", + name: "lt1-candidates", candidates: []tailcfg.NodeView{}, }, { - name: "1 candidate", + name: "1-candidate", candidates: []tailcfg.NodeView{ makePeer(2, withExitRoutes(), withLocation(location20.View())), }, @@ -5219,7 +5481,7 @@ func TestSuggestExitNodeLongLatDistance(t *testing.T) { want float64 }{ { - name: "zero values", + name: "zero-values", fromLat: 0, fromLong: 0, toLat: 0, @@ -5227,7 +5489,7 @@ func TestSuggestExitNodeLongLatDistance(t *testing.T) { want: 0, }, { - name: "valid values", + name: "valid-values", fromLat: 40.73061, fromLong: -73.935242, toLat: 37.3382082, @@ -5235,7 +5497,8 @@ func TestSuggestExitNodeLongLatDistance(t *testing.T) { want: 4117266.873301274, }, { - name: "valid values, locations in north and south of equator", + // Locations in north and south of equator. + name: "valid-values-cross-equator", fromLat: 40.73061, fromLong: -73.935242, toLat: -33.861481, @@ -5640,13 +5903,13 @@ func TestMinLatencyDERPregion(t *testing.T) { wantRegion int }{ { - name: "regions, no latency values", + name: "regions-no-latency", regions: []int{1, 2, 3}, wantRegion: 0, report: &netcheck.Report{}, }, { - name: "regions, different latency values", + name: "regions-different-latency", regions: []int{1, 2, 3}, wantRegion: 2, report: &netcheck.Report{ @@ -5658,7 +5921,7 @@ func TestMinLatencyDERPregion(t *testing.T) { }, }, { - name: "regions, same values", + name: "regions-same-latency", regions: []int{1, 2, 3}, wantRegion: 1, report: &netcheck.Report{ @@ -5805,7 +6068,7 @@ func TestFillAllowedSuggestions(t *testing.T) { want: []tailcfg.StableNodeID{"one", "three", "four", "two"}, // order should not matter }, { - name: "preserve case", + name: "preserve-case", allowPolicy: []string{"ABC", "def", "gHiJ"}, want: []tailcfg.StableNodeID{"ABC", "def", "gHiJ"}, }, @@ -5959,61 +6222,61 @@ func TestNotificationTargetMatch(t *testing.T) { wantMatch: false, }, { - name: "FilterByUID+CID/Nil", + name: "FilterByUID-CID/Nil", target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, actor: nil, wantMatch: false, }, { - name: "FilterByUID+CID/NoUID/NoCID", + name: "FilterByUID-CID/NoUID/NoCID", target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, actor: &ipnauth.TestActor{}, wantMatch: false, }, { - name: "FilterByUID+CID/NoUID/SameCID", + name: "FilterByUID-CID/NoUID/SameCID", target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, wantMatch: false, }, { - name: "FilterByUID+CID/NoUID/DifferentCID", + name: "FilterByUID-CID/NoUID/DifferentCID", target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("B")}, wantMatch: false, }, { - name: "FilterByUID+CID/SameUID/NoCID", + name: "FilterByUID-CID/SameUID/NoCID", target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4"}, wantMatch: false, }, { - name: "FilterByUID+CID/SameUID/SameCID", + name: "FilterByUID-CID/SameUID/SameCID", target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, wantMatch: true, }, { - name: "FilterByUID+CID/SameUID/DifferentCID", + name: "FilterByUID-CID/SameUID/DifferentCID", target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("B")}, wantMatch: false, }, { - name: "FilterByUID+CID/DifferentUID/NoCID", + name: "FilterByUID-CID/DifferentUID/NoCID", target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8"}, wantMatch: false, }, { - name: "FilterByUID+CID/DifferentUID/SameCID", + name: "FilterByUID-CID/DifferentUID/SameCID", target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8", CID: ipnauth.ClientIDFrom("A")}, wantMatch: false, }, { - name: "FilterByUID+CID/DifferentUID/DifferentCID", + name: "FilterByUID-CID/DifferentUID/DifferentCID", target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8", CID: ipnauth.ClientIDFrom("B")}, wantMatch: false, @@ -6399,13 +6662,13 @@ func TestConfigFileReload(t *testing.T) { initial: &conffile.Config{ Parsed: ipn.ConfigVAlpha{ Version: "alpha0", - Hostname: ptr.To("initial-host"), + Hostname: new("initial-host"), }, }, updated: &conffile.Config{ Parsed: ipn.ConfigVAlpha{ Version: "alpha0", - Hostname: ptr.To("updated-host"), + Hostname: new("updated-host"), }, }, checkFn: func(t *testing.T, b *LocalBackend) { @@ -7287,6 +7550,98 @@ func TestDeps(t *testing.T) { }.Check(t) } +func TestOnClientVersionRespectsAutoUpdateCheck(t *testing.T) { + lb := newTestLocalBackend(t) + + cv := &tailcfg.ClientVersion{ + RunningLatest: false, + LatestVersion: "1.96.0", + } + + // With Check disabled, onClientVersion should cache but not broadcast. + lb.SetPrefsForTest(&ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{Check: false}, + }) + + nw := newNotificationWatcher(t, lb, ipnauth.Self) + nw.watch(0, nil, unexpectedClientVersion) + lb.onClientVersion(cv) + nw.check() + + // Verify it was cached despite not being broadcast. + lb.mu.Lock() + cached := lb.lastClientVersion + lb.mu.Unlock() + if cached == nil || cached.LatestVersion != "1.96.0" { + t.Fatalf("lastClientVersion not cached: got %v", cached) + } + + // With Check enabled, onClientVersion should broadcast. + lb.SetPrefsForTest(&ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{Check: true}, + }) + + nw.watch(0, []wantedNotification{ + wantClientVersionNotify("1.96.0"), + }) + lb.onClientVersion(cv) + nw.check() +} + +func TestWatchNotificationsInitialClientVersion(t *testing.T) { + lb := newTestLocalBackend(t) + + cv := &tailcfg.ClientVersion{ + RunningLatest: false, + LatestVersion: "1.96.0", + } + + // Set Check=true and cache a ClientVersion. + lb.SetPrefsForTest(&ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{Check: true}, + }) + lb.mu.Lock() + lb.lastClientVersion = cv + lb.mu.Unlock() + + // Watch with NotifyInitialClientVersion should include ClientVersion. + nw := newNotificationWatcher(t, lb, ipnauth.Self) + nw.watch(ipn.NotifyInitialClientVersion, []wantedNotification{ + wantClientVersionNotify("1.96.0"), + }) + nw.check() + + // Watch without the flag, should not include it. + nw2 := newNotificationWatcher(t, lb, ipnauth.Self) + nw2.watch(0, nil, unexpectedClientVersion) + nw2.check() + + // Watch with the flag but Check=false, should not include it. + lb.SetPrefsForTest(&ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{Check: false}, + }) + nw3 := newNotificationWatcher(t, lb, ipnauth.Self) + nw3.watch(ipn.NotifyInitialClientVersion, nil, unexpectedClientVersion) + nw3.check() +} + +func wantClientVersionNotify(wantLatest string) wantedNotification { + return wantedNotification{ + name: fmt.Sprintf("ClientVersion-%s", wantLatest), + cond: func(_ testing.TB, _ ipnauth.Actor, n *ipn.Notify) bool { + return n.ClientVersion != nil && n.ClientVersion.LatestVersion == wantLatest + }, + } +} + +func unexpectedClientVersion(t testing.TB, _ ipnauth.Actor, n *ipn.Notify) bool { + if n.ClientVersion != nil { + t.Errorf("unexpected ClientVersion: %v", n.ClientVersion) + return true + } + return false +} + func checkError(tb testing.TB, got, want error, fatal bool) { tb.Helper() f := tb.Errorf @@ -7362,28 +7717,28 @@ func TestStripKeysFromPrefs(t *testing.T) { genNotify := map[string]func() ipn.Notify{ "Notify.Prefs.Đļ.Persist.PrivateNodeKey": func() ipn.Notify { return ipn.Notify{ - Prefs: ptr.To((&ipn.Prefs{ + Prefs: new((&ipn.Prefs{ Persist: &persist.Persist{PrivateNodeKey: key.NewNode()}, }).View()), } }, "Notify.Prefs.Đļ.Persist.OldPrivateNodeKey": func() ipn.Notify { return ipn.Notify{ - Prefs: ptr.To((&ipn.Prefs{ + Prefs: new((&ipn.Prefs{ Persist: &persist.Persist{OldPrivateNodeKey: key.NewNode()}, }).View()), } }, "Notify.Prefs.Đļ.Persist.NetworkLockKey": func() ipn.Notify { return ipn.Notify{ - Prefs: ptr.To((&ipn.Prefs{ + Prefs: new((&ipn.Prefs{ Persist: &persist.Persist{NetworkLockKey: key.NewNLPrivate()}, }).View()), } }, "Notify.Prefs.Đļ.Persist.AttestationKey": func() ipn.Notify { return ipn.Notify{ - Prefs: ptr.To((&ipn.Prefs{ + Prefs: new((&ipn.Prefs{ Persist: &persist.Persist{AttestationKey: new(fakeAttestationKey)}, }).View()), } @@ -7408,6 +7763,7 @@ func TestStripKeysFromPrefs(t *testing.T) { ch := make(chan *ipn.Notify, 1) b := &LocalBackend{ extHost: h, + health: health.NewTracker(eventbustest.NewBus(t)), notifyWatchers: map[string]*watchSession{ "test": {ch: ch}, }, @@ -7570,3 +7926,235 @@ func TestRouteAllDisabled(t *testing.T) { }) } } + +// TestAdvertiseRoute_InvalidPrefix tests that AdvertiseRoute rejects routes +// with non-address bits set in the prefix. +func TestAdvertiseRoute_InvalidPrefix(t *testing.T) { + b := newTestLocalBackend(t) + + tests := []struct { + name string + routes []netip.Prefix + wantErr bool + }{ + { + name: "valid_routes", + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("2001:db8::/32"), + }, + wantErr: false, + }, + { + name: "invalid_ipv4_route", + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.1/24"), // has non-address bits + }, + wantErr: true, + }, + { + name: "invalid_ipv6_route", + routes: []netip.Prefix{ + netip.MustParsePrefix("2a01:4f9:c010:c015::1/64"), // has non-address bits + }, + wantErr: true, + }, + { + name: "mixed_valid_and_invalid", + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // valid + netip.MustParsePrefix("192.168.1.1/16"), // invalid - this should cause rejection + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := b.AdvertiseRoute(tt.routes...) + if (err != nil) != tt.wantErr { + t.Errorf("AdvertiseRoute() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// TestEditPrefs_InvalidAdvertiseRoutes tests that EditPrefs (used by the local +// API) rejects routes with non-address bits set. +func TestEditPrefs_InvalidAdvertiseRoutes(t *testing.T) { + b := newTestLocalBackend(t) + + tests := []struct { + name string + routes []netip.Prefix + wantErr bool + }{ + { + name: "valid_routes", + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("2001:db8::/32"), + }, + wantErr: false, + }, + { + name: "invalid_ipv4_route", + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.1/24"), // has non-address bits + }, + wantErr: true, + }, + { + name: "invalid_ipv6_route", + routes: []netip.Prefix{ + netip.MustParsePrefix("fdf2:8bc1:6276:4f3f:dc33:c4ff:fe0b:120a/64"), // has non-address bits + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mp := &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + AdvertiseRoutes: tt.routes, + }, + AdvertiseRoutesSet: true, + } + + _, err := b.EditPrefs(mp) + if (err != nil) != tt.wantErr { + t.Errorf("EditPrefs() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestNoSNATWithAdvertisedExitNodeWarning(t *testing.T) { + exitRoutes := []netip.Prefix{ + netip.MustParsePrefix("0.0.0.0/0"), + netip.MustParsePrefix("::/0"), + } + warnCode := health.WarnableCode("nosnat-with-advertised-exit-node") + + tests := []struct { + name string + prefs *ipn.Prefs + wantWarning bool + }{ + { + name: "no-snat-without-exit-node", + prefs: &ipn.Prefs{ + NoSNAT: true, + AdvertiseRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + }, + wantWarning: false, + }, + { + name: "snat-enabled-with-exit-node", + prefs: &ipn.Prefs{ + NoSNAT: false, + AdvertiseRoutes: exitRoutes, + }, + wantWarning: false, + }, + { + name: "no-snat-with-exit-node", + prefs: &ipn.Prefs{ + NoSNAT: true, + AdvertiseRoutes: exitRoutes, + }, + wantWarning: true, + }, + { + name: "no-snat-with-exit-node-and-subnet", + prefs: &ipn.Prefs{ + NoSNAT: true, + AdvertiseRoutes: append([]netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, exitRoutes...), + }, + wantWarning: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := newTestLocalBackend(t) + b.SetPrefsForTest(tt.prefs) + _, hasWarning := b.HealthTracker().CurrentState().Warnings[warnCode] + if hasWarning != tt.wantWarning { + t.Errorf("warning present = %v, want %v", hasWarning, tt.wantWarning) + } + }) + } + + // Verify that the warning clears when the conflicting combination is resolved. + t.Run("warning-clears-on-fix", func(t *testing.T) { + b := newTestLocalBackend(t) + b.SetPrefsForTest(&ipn.Prefs{NoSNAT: true, AdvertiseRoutes: exitRoutes}) + if _, ok := b.HealthTracker().CurrentState().Warnings[warnCode]; !ok { + t.Fatal("expected warning to be set") + } + b.SetPrefsForTest(&ipn.Prefs{NoSNAT: false, AdvertiseRoutes: exitRoutes}) + if _, ok := b.HealthTracker().CurrentState().Warnings[warnCode]; ok { + t.Fatal("expected warning to be cleared after enabling SNAT") + } + }) +} + +// TestStartPreservesLoginFlags is a regression test for a bug where the +// LoginEphemeral flag stored on LocalBackend was silently dropped by the +// auto-login paths in Start() and setPrefsLocked(). The user-visible symptom +// was tsnet.Server.Ephemeral=true being ignored when combined with an auth +// key, because the resulting RegisterRequest.Ephemeral was false. +// +// The test manually constructs the LocalBackend to be able set +// loginFlags=LoginEphemeral, and then checks that at least one cc.Login call +// carried the LoginEphemeral bit. +func TestStartPreservesLoginFlags(t *testing.T) { + logf := tstest.WhileTestRunningLogger(t) + sys := tsd.NewSystem() + sys.Set(new(mem.Store)) + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker.Get(), sys.UserMetricsRegistry(), sys.Bus.Get()) + if err != nil { + t.Fatalf("NewFakeUserspaceEngine: %v", err) + } + t.Cleanup(e.Close) + sys.Set(e) + + b, err := NewLocalBackend(logf, logid.PublicID{}, sys, controlclient.LoginEphemeral) + if err != nil { + t.Fatalf("NewLocalBackend: %v", err) + } + t.Cleanup(b.Shutdown) + + var cc *mockControl + b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { + cc = newClient(t, opts) + return cc, nil + }) + + if err := b.Start(ipn.Options{ + UpdatePrefs: &ipn.Prefs{ + ControlURL: "https://controlplane.example.com", + WantRunning: false, + }, + AuthKey: "tskey-auth-test", + }); err != nil { + t.Fatalf("Start: %v", err) + } + + if _, err := b.EditPrefs(&ipn.MaskedPrefs{ + Prefs: ipn.Prefs{WantRunning: true}, + WantRunningSet: true, + }); err != nil { + t.Fatalf("EditPrefs: %v", err) + } + + cc.mu.Lock() + flags := cc.loginFlags + cc.mu.Unlock() + if flags&controlclient.LoginEphemeral == 0 { + t.Errorf("cc.Login was never called with LoginEphemeral; got flags=%v", flags) + } +} diff --git a/ipn/ipnlocal/netmapcache/netmapcache_test.go b/ipn/ipnlocal/netmapcache/netmapcache_test.go index b5a46d2982a04..ca66a17133a5e 100644 --- a/ipn/ipnlocal/netmapcache/netmapcache_test.go +++ b/ipn/ipnlocal/netmapcache/netmapcache_test.go @@ -275,7 +275,7 @@ var skippedMapFields = []string{ func checkFieldCoverage(t *testing.T, nm *netmap.NetworkMap) { t.Helper() - mt := reflect.TypeOf(nm).Elem() + mt := reflect.TypeFor[netmap.NetworkMap]() mv := reflect.ValueOf(nm).Elem() for i := 0; i < mt.NumField(); i++ { f := mt.Field(i) diff --git a/ipn/ipnlocal/netstack.go b/ipn/ipnlocal/netstack.go index b331d93e329de..eac9568b7f765 100644 --- a/ipn/ipnlocal/netstack.go +++ b/ipn/ipnlocal/netstack.go @@ -11,7 +11,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "tailscale.com/types/ptr" ) // TCPHandlerForDst returns a TCP handler for connections to dst, or nil if @@ -52,7 +51,7 @@ func (b *LocalBackend) TCPHandlerForDst(src, dst netip.AddrPort) (handler func(c // tell the difference between a long lived connection that is idle // vs a connection that is dead because the peer has gone away. // We pick 72h as that is typically sufficient for a long weekend. - opts = append(opts, ptr.To(tcpip.KeepaliveIdleOption(72*time.Hour))) + opts = append(opts, new(tcpip.KeepaliveIdleOption(72*time.Hour))) return b.handleSSHConn, opts } // TODO(will,sonia): allow customizing web client port ? diff --git a/ipn/ipnlocal/network-lock.go b/ipn/ipnlocal/network-lock.go index 242fec0287c65..75d5d95114162 100644 --- a/ipn/ipnlocal/network-lock.go +++ b/ipn/ipnlocal/network-lock.go @@ -27,6 +27,7 @@ import ( "tailscale.com/health/healthmsg" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" + "tailscale.com/ipn/store/mem" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/tka" @@ -38,20 +39,23 @@ import ( "tailscale.com/types/tkatype" "tailscale.com/util/mak" "tailscale.com/util/set" + "tailscale.com/util/testenv" ) // TODO(tom): RPC retry/backoff was broken and has been removed. Fix? var ( errMissingNetmap = errors.New("missing netmap: verify that you are logged in") - errNetworkLockNotActive = errors.New("network-lock is not active") - - tkaCompactionDefaults = tka.CompactionOptions{ - MinChain: 24, // Keep at minimum 24 AUMs since head. - MinAge: 14 * 24 * time.Hour, // Keep 2 weeks of AUMs. - } + errNetworkLockNotActive = errors.New("tailnet-lock is not active") ) +// IsNetworkLockNotActive reports whether the given error indicates that +// network-lock is not active. Stop-gap for feature/tailnetlock to check this +// until all of this is code is moved to the feature. +func IsNetworkLockNotActive(err error) bool { + return errors.Is(err, errNetworkLockNotActive) +} + type tkaState struct { profile ipn.ProfileID authority *tka.Authority @@ -92,16 +96,12 @@ func (b *LocalBackend) initTKALocked() error { return fmt.Errorf("initializing tka: %v", err) } - if err := authority.Compact(storage, tkaCompactionDefaults); err != nil { - b.logf("tka compaction failed: %v", err) - } - b.tka = &tkaState{ profile: cp.ID(), authority: authority, storage: storage, } - b.logf("tka initialized at head %x", authority.Head()) + b.logf("tka initialized at head %s", authority.Head()) } return nil @@ -304,7 +304,11 @@ func (b *LocalBackend) tkaSyncIfNeeded(nm *netmap.NetworkMap, prefs ipn.PrefsVie wantEnabled := nm.TKAEnabled if isEnabled || wantEnabled { - b.logf("tkaSyncIfNeeded: isEnabled=%t, wantEnabled=%t, head=%v", isEnabled, wantEnabled, nm.TKAHead) + nodeHead := "" + if b.tka != nil { + nodeHead = b.tka.authority.Head().String() + } + b.logf("tkaSyncIfNeeded: isEnabled=%t, wantEnabled=%t, nodeHead=%v, netmapHead=%v", isEnabled, wantEnabled, nodeHead, nm.TKAHead) } ourNodeKey, ok := prefs.Persist().PublicNodeKeyOK() @@ -360,7 +364,7 @@ func (b *LocalBackend) tkaSyncIfNeeded(nm *netmap.NetworkMap, prefs ipn.PrefsVie // // We run this on every sync so that clients compact consistently. In many // cases this will be a no-op. - if err := b.tka.authority.Compact(b.tka.storage, tkaCompactionDefaults); err != nil { + if err := b.tka.authority.Compact(b.tka.storage, tka.CompactionDefaults); err != nil { return fmt.Errorf("tka compact: %w", err) } } @@ -407,7 +411,7 @@ func (b *LocalBackend) tkaSyncLocked(ourNodeKey key.NodePublic) error { // has updates for us, or we have updates for the control plane. // // TODO(tom): Do we want to keep processing even if the Inform fails? Need - // to think through if theres holdback concerns here or not. + // to think through if there's holdback concerns here or not. if len(offerResp.MissingAUMs) > 0 { aums := make([]tka.AUM, len(offerResp.MissingAUMs)) for i, a := range offerResp.MissingAUMs { @@ -654,15 +658,10 @@ func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byt // the filesystem until we've finished the initialization sequence, // just in case something goes wrong. _, genesisAUM, err := tka.Create(tka.ChonkMem(), tka.State{ - Keys: keys, - // TODO(tom): s/tka.State.DisablementSecrets/tka.State.DisablementValues - // This will center on consistent nomenclature: - // - DisablementSecret: value needed to disable. - // - DisablementValue: the KDF of the disablement secret, a public value. - DisablementSecrets: disablementValues, - - StateID1: binary.LittleEndian.Uint64(entropy[:8]), - StateID2: binary.LittleEndian.Uint64(entropy[8:]), + Keys: keys, + DisablementValues: disablementValues, + StateID1: binary.LittleEndian.Uint64(entropy[:8]), + StateID2: binary.LittleEndian.Uint64(entropy[8:]), }, nlPriv) if err != nil { return fmt.Errorf("tka.Create: %v", err) @@ -708,6 +707,7 @@ func (b *LocalBackend) NetworkLockAllowed() bool { // Only use is in tests. func (b *LocalBackend) NetworkLockVerifySignatureForTest(nks tkatype.MarshaledSignature, nodeKey key.NodePublic) error { + testenv.AssertInTest() b.mu.Lock() defer b.mu.Unlock() if b.tka == nil { @@ -718,6 +718,7 @@ func (b *LocalBackend) NetworkLockVerifySignatureForTest(nks tkatype.MarshaledSi // Only use is in tests. func (b *LocalBackend) NetworkLockKeyTrustedForTest(keyID tkatype.KeyID) bool { + testenv.AssertInTest() b.mu.Lock() defer b.mu.Unlock() if b.tka == nil { @@ -806,7 +807,7 @@ func (b *LocalBackend) NetworkLockSign(nodeKey key.NodePublic, rotationPublic [] func (b *LocalBackend) NetworkLockModify(addKeys, removeKeys []tka.Key) (err error) { defer func() { if err != nil { - err = fmt.Errorf("modify network-lock keys: %w", err) + err = fmt.Errorf("modify tailnet-lock keys: %w", err) } }() @@ -1126,7 +1127,7 @@ func (b *LocalBackend) NetworkLockWrapPreauthKey(preauthKey string, tkaKey key.N return "", fmt.Errorf("signing failed: %w", err) } - b.logf("Generated network-lock credential signature using %s", tkaKey.Public().CLIString()) + b.logf("Generated tailnet-lock credential signature using %s", tkaKey.Public().CLIString()) return fmt.Sprintf("%s--TL%s-%s", preauthKey, tkaSuffixEncoder.EncodeToString(sig.Serialize()), tkaSuffixEncoder.EncodeToString(priv)), nil } @@ -1487,3 +1488,24 @@ func (b *LocalBackend) tkaReadAffectedSigs(ourNodeKey key.NodePublic, key tkatyp return a, nil } + +// LocalBackendWithTKAForTest creates a LocalBackend with an initialized TKA +// state for testing tailnet lock from the feature/tailnetlock package. Will be +// removed when tailnet lock is fully moved to its own package. Do not use this +// from any other package. +func LocalBackendWithTKAForTest(chonk tka.CompactableChonk, tka *tka.Authority) *LocalBackend { + testenv.AssertInTest() + + var state *tkaState + if tka != nil { + state = &tkaState{ + authority: tka, + storage: chonk, + } + } + return &LocalBackend{ + store: &mem.Store{}, + logf: logger.Discard, + tka: state, + } +} diff --git a/ipn/ipnlocal/network-lock_test.go b/ipn/ipnlocal/network-lock_test.go index 8aa0a877b8dd3..eead2d8926f27 100644 --- a/ipn/ipnlocal/network-lock_test.go +++ b/ipn/ipnlocal/network-lock_test.go @@ -104,8 +104,8 @@ func TestTKAEnablementFlow(t *testing.T) { key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} chonk := tka.ChonkMem() a1, genesisAUM, err := tka.Create(chonk, tka.State{ - Keys: []tka.Key{key}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{0xa5}, 32)}, + Keys: []tka.Key{key}, + DisablementValues: [][]byte{bytes.Repeat([]byte{0xa5}, 32)}, }, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) @@ -158,6 +158,7 @@ func TestTKAEnablementFlow(t *testing.T) { cc: cc, ccAuto: cc, logf: t.Logf, + health: health.NewTracker(eventbustest.NewBus(t)), pm: pm, store: pm.Store(), } @@ -195,8 +196,8 @@ func TestTKADisablementFlow(t *testing.T) { t.Fatal(err) } authority, _, err := tka.Create(chonk, tka.State{ - Keys: []tka.Key{key}, - DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)}, + Keys: []tka.Key{key}, + DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, }, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) @@ -244,6 +245,7 @@ func TestTKADisablementFlow(t *testing.T) { cc: cc, ccAuto: cc, logf: t.Logf, + health: health.NewTracker(eventbustest.NewBus(t)), tka: &tkaState{ authority: authority, storage: chonk, @@ -300,9 +302,9 @@ func TestTKASync(t *testing.T) { } tcs := []tkaSyncScenario{ - {name: "up to date"}, + {name: "up-to-date"}, { - name: "control has an update", + name: "control-has-an-update", controlAUMs: func(t *testing.T, a *tka.Authority, storage tka.Chonk, signer tka.Signer) []tka.AUM { b := a.NewUpdater(signer) if err := b.RemoveKey(someKey.MustID()); err != nil { @@ -317,7 +319,7 @@ func TestTKASync(t *testing.T) { }, { // AKA 'control data loss' scenario - name: "node has an update", + name: "node-has-an-update", nodeAUMs: func(t *testing.T, a *tka.Authority, storage tka.Chonk, signer tka.Signer) []tka.AUM { b := a.NewUpdater(signer) if err := b.RemoveKey(someKey.MustID()); err != nil { @@ -332,7 +334,7 @@ func TestTKASync(t *testing.T) { }, { // AKA 'control data loss + update in the meantime' scenario - name: "node and control diverge", + name: "node-and-control-diverge", controlAUMs: func(t *testing.T, a *tka.Authority, storage tka.Chonk, signer tka.Signer) []tka.AUM { b := a.NewUpdater(signer) if err := b.SetKeyMeta(someKey.MustID(), map[string]string{"ye": "swiggity"}); err != nil { @@ -368,8 +370,8 @@ func TestTKASync(t *testing.T) { key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} controlStorage := tka.ChonkMem() controlAuthority, bootstrap, err := tka.Create(controlStorage, tka.State{ - Keys: []tka.Key{key, someKey}, - DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)}, + Keys: []tka.Key{key, someKey}, + DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, }, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) @@ -428,6 +430,7 @@ func TestTKASync(t *testing.T) { cc: cc, ccAuto: cc, logf: t.Logf, + health: health.NewTracker(eventbustest.NewBus(t)), pm: pm, store: pm.Store(), tka: &tkaState{ @@ -478,8 +481,8 @@ func TestTKASyncTriggersCompact(t *testing.T) { controlStorage := tka.ChonkMem() controlStorage.SetClock(clock) controlAuthority, bootstrap, err := tka.Create(controlStorage, tka.State{ - Keys: []tka.Key{key, someKey}, - DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)}, + Keys: []tka.Key{key, someKey}, + DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, }, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) @@ -544,6 +547,7 @@ func TestTKASyncTriggersCompact(t *testing.T) { cc: cc, ccAuto: cc, logf: t.Logf, + health: health.NewTracker(eventbustest.NewBus(t)), pm: pm, store: pm.Store(), tka: &tkaState{ @@ -608,16 +612,17 @@ func TestTKAFilterNetmap(t *testing.T) { nlKey := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} storage := tka.ChonkMem() authority, _, err := tka.Create(storage, tka.State{ - Keys: []tka.Key{nlKey}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{0xa5}, 32)}, + Keys: []tka.Key{nlKey}, + DisablementValues: [][]byte{bytes.Repeat([]byte{0xa5}, 32)}, }, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) } b := &LocalBackend{ - logf: t.Logf, - tka: &tkaState{authority: authority}, + logf: t.Logf, + health: health.NewTracker(eventbustest.NewBus(t)), + tka: &tkaState{authority: authority}, } n1, n2, n3, n4, n5 := key.NewNode(), key.NewNode(), key.NewNode(), key.NewNode(), key.NewNode() @@ -771,8 +776,8 @@ func TestTKADisable(t *testing.T) { t.Fatal(err) } authority, _, err := tka.Create(chonk, tka.State{ - Keys: []tka.Key{key}, - DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)}, + Keys: []tka.Key{key}, + DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, }, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) @@ -822,6 +827,7 @@ func TestTKADisable(t *testing.T) { cc: cc, ccAuto: cc, logf: t.Logf, + health: health.NewTracker(eventbustest.NewBus(t)), tka: &tkaState{ profile: pm.CurrentProfile().ID(), authority: authority, @@ -859,8 +865,8 @@ func TestTKASign(t *testing.T) { t.Fatal(err) } authority, _, err := tka.Create(chonk, tka.State{ - Keys: []tka.Key{key}, - DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)}, + Keys: []tka.Key{key}, + DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, }, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) @@ -887,6 +893,7 @@ func TestTKASign(t *testing.T) { cc: cc, ccAuto: cc, logf: t.Logf, + health: health.NewTracker(eventbustest.NewBus(t)), tka: &tkaState{ authority: authority, storage: chonk, @@ -918,8 +925,8 @@ func TestTKAForceDisable(t *testing.T) { t.Fatal(err) } authority, genesis, err := tka.Create(chonk, tka.State{ - Keys: []tka.Key{key}, - DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)}, + Keys: []tka.Key{key}, + DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, }, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) @@ -1006,8 +1013,8 @@ func TestTKAAffectedSigs(t *testing.T) { t.Fatal(err) } authority, _, err := tka.Create(chonk, tka.State{ - Keys: []tka.Key{tkaKey}, - DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)}, + Keys: []tka.Key{tkaKey}, + DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, }, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) @@ -1020,7 +1027,7 @@ func TestTKAAffectedSigs(t *testing.T) { wantErr string }{ { - "no error", + "no-error", func() *tka.NodeKeySignature { sig, _ := signNodeKey(tailcfg.TKASignInfo{NodePublic: nodePriv.Public()}, nlPriv) return sig @@ -1028,7 +1035,7 @@ func TestTKAAffectedSigs(t *testing.T) { "", }, { - "signature for different keyID", + "signature-for-different-keyID", func() *tka.NodeKeySignature { sig, _ := signNodeKey(tailcfg.TKASignInfo{NodePublic: nodePriv.Public()}, untrustedKey) return sig @@ -1036,7 +1043,7 @@ func TestTKAAffectedSigs(t *testing.T) { fmt.Sprintf("got signature with keyID %X from request for %X", untrustedKey.KeyID(), nlPriv.KeyID()), }, { - "invalid signature", + "invalid-signature", func() *tka.NodeKeySignature { sig, _ := signNodeKey(tailcfg.TKASignInfo{NodePublic: nodePriv.Public()}, nlPriv) copy(sig.Signature, []byte{1, 2, 3, 4, 5, 6}) // overwrite with trash to invalid signature @@ -1083,6 +1090,7 @@ func TestTKAAffectedSigs(t *testing.T) { cc: cc, ccAuto: cc, logf: t.Logf, + health: health.NewTracker(eventbustest.NewBus(t)), tka: &tkaState{ authority: authority, storage: chonk, @@ -1135,8 +1143,8 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { t.Fatal(err) } authority, _, err := tka.Create(chonk, tka.State{ - Keys: []tka.Key{key, compromisedKey, cosignKey}, - DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)}, + Keys: []tka.Key{key, compromisedKey, cosignKey}, + DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, }, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) @@ -1168,6 +1176,7 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { cc: cc, ccAuto: cc, logf: t.Logf, + health: health.NewTracker(eventbustest.NewBus(t)), tka: &tkaState{ authority: authority, storage: chonk, @@ -1187,6 +1196,7 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { b := LocalBackend{ varRoot: temp, logf: t.Logf, + health: health.NewTracker(eventbustest.NewBus(t)), tka: &tkaState{ authority: authority, storage: chonk, diff --git a/ipn/ipnlocal/node_backend.go b/ipn/ipnlocal/node_backend.go index b70d71cb934f2..f8579900df139 100644 --- a/ipn/ipnlocal/node_backend.go +++ b/ipn/ipnlocal/node_backend.go @@ -24,12 +24,12 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" - "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/dnsname" "tailscale.com/util/eventbus" "tailscale.com/util/mak" "tailscale.com/util/slicesx" + "tailscale.com/util/testenv" "tailscale.com/wgengine/filter" ) @@ -80,6 +80,13 @@ type nodeBackend struct { eventClient *eventbus.Client derpMapViewPub *eventbus.Publisher[tailcfg.DERPMapView] + // homeDERP lives here temporarily. as long as mapSession is short lived, we + // don't have a location delivering netmaps to local backend that knows our + // homeDERP hence why it is cached here for now. + // TODO(cmol): move this field into a refactored mapSession that is not + // short lived. + homeDERP atomic.Int64 + // TODO(nickkhyl): maybe use sync.RWMutex? mu syncs.Mutex // protects the following fields @@ -104,6 +111,16 @@ type nodeBackend struct { // nodeByAddr maps nodes' own addresses (excluding subnet routes) to node IDs. // It is mutated in place (with mu held) and must not escape the [nodeBackend]. nodeByAddr map[netip.Addr]tailcfg.NodeID + + // nodeByKey is an index of node public key to node ID for fast lookups. + // It is mutated in place (with mu held) and must not escape the [nodeBackend]. + nodeByKey map[key.NodePublic]tailcfg.NodeID + + // keyWaitersForTest is the test-only registry of channels waiting for + // a given peer key to first appear in the netmap. See + // [nodeBackend.AwaitNodeKeyForTest]. It is populated lazily and remains + // nil in production, where no test installs a waiter. + keyWaitersForTest map[key.NodePublic]chan struct{} } func newNodeBackend(ctx context.Context, logf logger.Logf, bus *eventbus.Bus) *nodeBackend { @@ -193,19 +210,8 @@ func (nb *nodeBackend) NodeByAddr(ip netip.Addr) (_ tailcfg.NodeID, ok bool) { func (nb *nodeBackend) NodeByKey(k key.NodePublic) (_ tailcfg.NodeID, ok bool) { nb.mu.Lock() defer nb.mu.Unlock() - if nb.netMap == nil { - return 0, false - } - if self := nb.netMap.SelfNode; self.Valid() && self.Key() == k { - return self.ID(), true - } - // TODO(bradfitz,nickkhyl): add nodeByKey like nodeByAddr instead of walking peers. - for _, n := range nb.peers { - if n.Key() == k { - return n.ID(), true - } - } - return 0, false + nid, ok := nb.nodeByKey[k] + return nid, ok } func (nb *nodeBackend) NodeByID(id tailcfg.NodeID) (_ tailcfg.NodeView, ok bool) { @@ -414,7 +420,7 @@ func (nb *nodeBackend) netMapWithPeers() *netmap.NetworkMap { if nb.netMap == nil { return nil } - nm := ptr.To(*nb.netMap) // shallow clone + nm := new(*nb.netMap) // shallow clone nm.Peers = slicesx.MapValues(nb.peers) slices.SortFunc(nm.Peers, func(a, b tailcfg.NodeView) int { return cmp.Compare(a.ID(), b.ID()) @@ -427,7 +433,9 @@ func (nb *nodeBackend) SetNetMap(nm *netmap.NetworkMap) { defer nb.mu.Unlock() nb.netMap = nm nb.updateNodeByAddrLocked() + nb.updateNodeByKeyLocked() nb.updatePeersLocked() + nb.signalKeyWaitersForTestLocked() if nm != nil { nb.derpMapViewPub.Publish(nm.DERPMap.View()) } else { @@ -435,6 +443,43 @@ func (nb *nodeBackend) SetNetMap(nm *netmap.NetworkMap) { } } +// AwaitNodeKeyForTest returns a channel that is closed once a peer with the +// given node key first appears in this nodeBackend's peer index, or +// immediately (a closed channel) if it's already present. It is intended for +// in-process benchmarks that drive synthetic netmap deltas and need a +// zero-overhead signal that the client has applied a delta, replacing +// poll-based [local.Client.WhoIsNodeKey] loops in tests. It panics outside +// of tests. +func (nb *nodeBackend) AwaitNodeKeyForTest(k key.NodePublic) <-chan struct{} { + testenv.AssertInTest() + nb.mu.Lock() + defer nb.mu.Unlock() + if _, ok := nb.nodeByKey[k]; ok { + return syncs.ClosedChan() + } + if ch, ok := nb.keyWaitersForTest[k]; ok { + return ch + } + ch := make(chan struct{}) + mak.Set(&nb.keyWaitersForTest, k, ch) + return ch +} + +// signalKeyWaitersForTestLocked closes any waiter channels whose keys now +// appear in nb.nodeByKey. It is cheap when there are no waiters, which is +// the common case in production. It is called from [nodeBackend.SetNetMap] +// after the per-key index has been rebuilt. +// +// Caller must hold nb.mu. +func (nb *nodeBackend) signalKeyWaitersForTestLocked() { + for k, ch := range nb.keyWaitersForTest { + if _, ok := nb.nodeByKey[k]; ok { + close(ch) + delete(nb.keyWaitersForTest, k) + } + } +} + func (nb *nodeBackend) updateNodeByAddrLocked() { nm := nb.netMap if nm == nil { @@ -471,6 +516,37 @@ func (nb *nodeBackend) updateNodeByAddrLocked() { } } +func (nb *nodeBackend) updateNodeByKeyLocked() { + nm := nb.netMap + if nm == nil { + nb.nodeByKey = nil + return + } + + if nb.nodeByKey == nil { + nb.nodeByKey = map[key.NodePublic]tailcfg.NodeID{} + } + // First pass, mark everything unwanted. + for k := range nb.nodeByKey { + nb.nodeByKey[k] = 0 + } + addNode := func(n tailcfg.NodeView) { + nb.nodeByKey[n.Key()] = n.ID() + } + if nm.SelfNode.Valid() { + addNode(nm.SelfNode) + } + for _, p := range nm.Peers { + addNode(p) + } + // Third pass, actually delete the unwanted items. + for k, v := range nb.nodeByKey { + if v == 0 { + delete(nb.nodeByKey, k) + } + } +} + func (nb *nodeBackend) updatePeersLocked() { nm := nb.netMap if nm == nil { @@ -840,7 +916,7 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg. addSplitDNSRoutes(nm.DNS.Routes) // Add split DNS routes for conn25 - conn25DNSTargets := appc.PickSplitDNSPeers(nm.HasCap, nm.SelfNode, peers) + conn25DNSTargets := appc.PickSplitDNSPeers(nm.HasCap, nm.SelfNode, peers, prefs.AppConnector().Advertise) if conn25DNSTargets != nil { var m map[string][]*dnstype.Resolver for domain, candidateSplitDNSPeers := range conn25DNSTargets { diff --git a/ipn/ipnlocal/node_backend_test.go b/ipn/ipnlocal/node_backend_test.go index f1f38dae6aee1..ca61624b8419b 100644 --- a/ipn/ipnlocal/node_backend_test.go +++ b/ipn/ipnlocal/node_backend_test.go @@ -12,7 +12,6 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/types/netmap" - "tailscale.com/types/ptr" "tailscale.com/util/eventbus" ) @@ -146,7 +145,7 @@ func TestNodeBackendReachability(t *testing.T) { name: "disabled/offline", cap: false, peer: tailcfg.Node{ - Online: ptr.To(false), + Online: new(false), }, want: false, }, @@ -154,7 +153,7 @@ func TestNodeBackendReachability(t *testing.T) { name: "disabled/online", cap: false, peer: tailcfg.Node{ - Online: ptr.To(true), + Online: new(true), }, want: true, }, @@ -162,7 +161,7 @@ func TestNodeBackendReachability(t *testing.T) { name: "enabled/offline", cap: true, peer: tailcfg.Node{ - Online: ptr.To(false), + Online: new(false), }, want: true, }, @@ -170,7 +169,7 @@ func TestNodeBackendReachability(t *testing.T) { name: "enabled/online", cap: true, peer: tailcfg.Node{ - Online: ptr.To(true), + Online: new(true), }, want: true, }, diff --git a/ipn/ipnlocal/peerapi.go b/ipn/ipnlocal/peerapi.go index aa4c1ef527c6c..d72a519ab1feb 100644 --- a/ipn/ipnlocal/peerapi.go +++ b/ipn/ipnlocal/peerapi.go @@ -103,7 +103,7 @@ func (s *peerAPIServer) listen(ip netip.Addr, tunIfIndex int) (ln net.Listener, // deterministic that people will bake this into clients. // We try a few times just in case something's already // listening on that port (on all interfaces, probably). - for try := uint8(0); try < 5; try++ { + for try := range uint8(5) { a16 := ip.As16() hashData := a16[len(a16)-3:] hashData[0] += try @@ -192,7 +192,7 @@ func (pln *peerAPIListener) ServeConn(src netip.AddrPort, c net.Conn) { c.Close() return } - nm := pln.lb.NetMap() + nm := pln.lb.NetMapNoPeers() if nm == nil || !nm.SelfNode.Valid() { logf("peerapi: no netmap") c.Close() diff --git a/ipn/ipnlocal/profiles.go b/ipn/ipnlocal/profiles.go index 430fa63152a77..4e073e5c9aeba 100644 --- a/ipn/ipnlocal/profiles.go +++ b/ipn/ipnlocal/profiles.go @@ -274,7 +274,7 @@ func (pm *profileManager) matchingProfiles(uid ipn.WindowsUserID, f func(ipn.Log func (pm *profileManager) findMatchingProfiles(uid ipn.WindowsUserID, prefs ipn.PrefsView) []ipn.LoginProfileView { return pm.matchingProfiles(uid, func(p ipn.LoginProfileView) bool { return p.ControlURL() == prefs.ControlURL() && - (p.UserProfile().ID == prefs.Persist().UserProfile().ID || + (p.UserProfile().ID() == prefs.Persist().UserProfile().ID() || p.NodeID() == prefs.Persist().NodeID()) }) } @@ -337,7 +337,7 @@ func (pm *profileManager) setUnattendedModeAsConfigured() error { // across user switches to disambiguate the same account but a different tailnet. func (pm *profileManager) SetPrefs(prefsIn ipn.PrefsView, np ipn.NetworkProfile) error { cp := pm.currentProfile - if persist := prefsIn.Persist(); !persist.Valid() || persist.NodeID() == "" || persist.UserProfile().LoginName == "" { + if persist := prefsIn.Persist(); !persist.Valid() || persist.NodeID() == "" || persist.UserProfile().LoginName() == "" { // We don't know anything about this profile, so ignore it for now. return pm.setProfilePrefsNoPermCheck(pm.currentProfile, prefsIn.AsStruct().View()) } @@ -410,7 +410,7 @@ func (pm *profileManager) setProfilePrefs(lp *ipn.LoginProfile, prefsIn ipn.Pref // and it hasn't been persisted yet. We'll generate both an ID and [ipn.StateKey] // once the information is available and needs to be persisted. if lp.ID == "" { - if persist := prefsIn.Persist(); persist.Valid() && persist.NodeID() != "" && persist.UserProfile().LoginName != "" { + if persist := prefsIn.Persist(); persist.Valid() && persist.NodeID() != "" && persist.UserProfile().LoginName() != "" { // Generate an ID and [ipn.StateKey] now that we have the node info. lp.ID, lp.Key = newUnusedID(pm.knownProfiles) } @@ -425,7 +425,7 @@ func (pm *profileManager) setProfilePrefs(lp *ipn.LoginProfile, prefsIn ipn.Pref var up tailcfg.UserProfile if persist := prefsIn.Persist(); persist.Valid() { - up = persist.UserProfile() + up = *persist.UserProfile().AsStruct() if up.DisplayName == "" { up.DisplayName = up.LoginName } diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index d25251accd797..83b8027d7c02c 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -276,7 +276,7 @@ func (b *LocalBackend) updateServeTCPPortNetMapAddrListenersLocked(ports []uint1 } } - nm := b.NetMap() + nm := b.NetMapNoPeers() if nm == nil { b.logf("netMap is nil") return @@ -333,7 +333,7 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string return errors.New("can't reconfigure tailscaled when using a config file; config file is locked") } - nm := b.NetMap() + nm := b.NetMapNoPeers() if nm == nil { return errors.New("netMap is nil") } @@ -835,8 +835,8 @@ func (b *LocalBackend) proxyHandlerForBackend(backend string) (http.Handler, err targetURL, insecure := expandProxyArg(backend) // Handle unix: scheme specially - if strings.HasPrefix(targetURL, "unix:") { - socketPath := strings.TrimPrefix(targetURL, "unix:") + if after, ok := strings.CutPrefix(targetURL, "unix:"); ok { + socketPath := after if socketPath == "" { return nil, fmt.Errorf("empty unix socket path") } diff --git a/ipn/ipnlocal/serve_test.go b/ipn/ipnlocal/serve_test.go index b3f48b105c8f7..05f4936b2c299 100644 --- a/ipn/ipnlocal/serve_test.go +++ b/ipn/ipnlocal/serve_test.go @@ -194,6 +194,7 @@ func TestGetServeHandler(t *testing.T) { b := &LocalBackend{ serveConfig: tt.conf.View(), logf: t.Logf, + health: health.NewTracker(eventbustest.NewBus(t)), } req := &http.Request{ URL: &url.URL{ @@ -619,49 +620,49 @@ func TestServeHTTPProxyPath(t *testing.T) { wantRequestPath string }{ { - name: "/foo -> /foo, with mount point and path /foo", + name: "foo-to-foo-mount-foo", mountPoint: "/foo", proxyPath: "/foo", requestPath: "/foo", wantRequestPath: "/foo", }, { - name: "/foo/ -> /foo/, with mount point and path /foo", + name: "foo-slash-to-foo-slash-mount-foo", mountPoint: "/foo", proxyPath: "/foo", requestPath: "/foo/", wantRequestPath: "/foo/", }, { - name: "/foo -> /foo/, with mount point and path /foo/", + name: "foo-to-foo-slash-mount-foo-slash", mountPoint: "/foo/", proxyPath: "/foo/", requestPath: "/foo", wantRequestPath: "/foo/", }, { - name: "/-> /, with mount point and path /", + name: "root-to-root-mount-root", mountPoint: "/", proxyPath: "/", requestPath: "/", wantRequestPath: "/", }, { - name: "/foo -> /foo, with mount point and path /", + name: "foo-to-foo-mount-root", mountPoint: "/", proxyPath: "/", requestPath: "/foo", wantRequestPath: "/foo", }, { - name: "/foo/bar -> /foo/bar, with mount point and path /foo", + name: "foo-bar-to-foo-bar-mount-foo", mountPoint: "/foo", proxyPath: "/foo", requestPath: "/foo/bar", wantRequestPath: "/foo/bar", }, { - name: "/foo/bar/baz -> /foo/bar/baz, with mount point and path /foo", + name: "foo-bar-baz-to-foo-bar-baz-mount-foo", mountPoint: "/foo", proxyPath: "/foo", requestPath: "/foo/bar/baz", @@ -1191,7 +1192,9 @@ func TestServeFileOrDirectory(t *testing.T) { } } - b := &LocalBackend{} + b := &LocalBackend{ + health: health.NewTracker(eventbustest.NewBus(t)), + } tests := []struct { req string @@ -1457,7 +1460,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError bool }{ { - name: "empty existing config", + name: "empty-existing-config", description: "should be able to update with empty existing config", existing: &ipn.ServeConfig{}, incoming: &ipn.ServeConfig{ @@ -1468,7 +1471,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: false, }, { - name: "no existing config", + name: "no-existing-config", description: "should be able to update with no existing config", existing: nil, incoming: &ipn.ServeConfig{ @@ -1479,7 +1482,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: false, }, { - name: "empty incoming config", + name: "empty-incoming-config", description: "wiping config should work", existing: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{ @@ -1490,7 +1493,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: false, }, { - name: "no incoming config", + name: "no-incoming-config", description: "missing incoming config should not result in an error", existing: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{ @@ -1501,7 +1504,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: false, }, { - name: "non-overlapping update", + name: "non-overlapping-update", description: "non-overlapping update should work", existing: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{ @@ -1516,7 +1519,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: false, }, { - name: "overwriting background port", + name: "overwriting-background-port", description: "should be able to overwrite a background port", existing: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{ @@ -1535,7 +1538,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: false, }, { - name: "broken existing config", + name: "broken-existing-config", description: "broken existing config should not prevent new config updates", existing: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{ @@ -1573,7 +1576,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: false, }, { - name: "services same port as background", + name: "services-same-port-as-background", description: "services should be able to use the same port as background listeners", existing: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{ @@ -1592,7 +1595,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: false, }, { - name: "services tun mode", + name: "services-tun-mode", description: "TUN mode should be mutually exclusive with TCP or web handlers for new Services", existing: &ipn.ServeConfig{}, incoming: &ipn.ServeConfig{ @@ -1608,7 +1611,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: true, }, { - name: "new foreground listener", + name: "new-foreground-listener", description: "new foreground listeners must be on open ports", existing: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{ @@ -1627,7 +1630,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: true, }, { - name: "new background listener", + name: "new-background-listener", description: "new background listers cannot overwrite foreground listeners", existing: &ipn.ServeConfig{ Foreground: map[string]*ipn.ServeConfig{ @@ -1646,7 +1649,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: true, }, { - name: "serve type overwrite", + name: "serve-type-overwrite", description: "incoming configuration cannot change the serve type in use by a port", existing: &ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{ @@ -1665,7 +1668,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: true, }, { - name: "serve type overwrite services", + name: "serve-type-overwrite-services", description: "incoming Services configuration cannot change the serve type in use by a port", existing: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -1692,7 +1695,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: true, }, { - name: "tun mode with handlers", + name: "tun-mode-with-handlers", description: "Services cannot enable TUN mode if L4 or L7 handlers already exist", existing: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ @@ -1720,7 +1723,7 @@ func TestValidateServeConfigUpdate(t *testing.T) { wantError: true, }, { - name: "handlers with tun mode", + name: "handlers-with-tun-mode", description: "Services cannot add L4 or L7 handlers if TUN mode is already enabled", existing: &ipn.ServeConfig{ Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ diff --git a/ipn/ipnlocal/serve_unix_test.go b/ipn/ipnlocal/serve_unix_test.go index 2d1f0a1e34af8..9e641e0e521ba 100644 --- a/ipn/ipnlocal/serve_unix_test.go +++ b/ipn/ipnlocal/serve_unix_test.go @@ -8,6 +8,7 @@ package ipnlocal import ( "errors" "fmt" + "io" "net" "net/http" "net/http/httptest" @@ -22,26 +23,30 @@ import ( func TestExpandProxyArgUnix(t *testing.T) { tests := []struct { + name string input string wantURL string wantInsecure bool }{ { + name: "unix-tmp-sock", input: "unix:/tmp/test.sock", wantURL: "unix:/tmp/test.sock", }, { + name: "unix-var-run-docker-sock", input: "unix:/var/run/docker.sock", wantURL: "unix:/var/run/docker.sock", }, { + name: "unix-relative-sock", input: "unix:./relative.sock", wantURL: "unix:./relative.sock", }, } for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { gotURL, gotInsecure := expandProxyArg(tt.input) if gotURL != tt.wantURL { t.Errorf("expandProxyArg(%q) url = %q, want %q", tt.input, gotURL, tt.wantURL) @@ -101,6 +106,23 @@ func TestServeUnixSocket(t *testing.T) { if rp.url.Host != "localhost" { t.Errorf("url.Host = %q, want %q", rp.url.Host, "localhost") } + + req := httptest.NewRequest("GET", "http://foo.test.ts.net/", nil) + rec := httptest.NewRecorder() + + rp.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatal("unexpected response code:", rec.Code) + } + resp := rec.Result() + defer resp.Body.Close() + respB, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal("read error:", err) + } + if string(respB) != testResponse { + t.Fatalf("unexpected response: want: '%s'; got: '%s'", testResponse, string(respB)) + } } func TestServeUnixSocketErrors(t *testing.T) { diff --git a/ipn/ipnlocal/ssh_stub.go b/ipn/ipnlocal/ssh_stub.go deleted file mode 100644 index 9a997c9143f7b..0000000000000 --- a/ipn/ipnlocal/ssh_stub.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build ts_omit_ssh || ios || android || (!linux && !darwin && !freebsd && !openbsd && !plan9) - -package ipnlocal - -import ( - "errors" - - "tailscale.com/tailcfg" -) - -func (b *LocalBackend) getSSHHostKeyPublicStrings() ([]string, error) { - return nil, nil -} - -func (b *LocalBackend) getSSHUsernames(*tailcfg.C2NSSHUsernamesRequest) (*tailcfg.C2NSSHUsernamesResponse, error) { - return nil, errors.New("not implemented") -} diff --git a/ipn/ipnlocal/ssh_test.go b/ipn/ipnlocal/ssh_test.go deleted file mode 100644 index bb293d10ac4d6..0000000000000 --- a/ipn/ipnlocal/ssh_test.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || (darwin && !ios) - -package ipnlocal - -import ( - "encoding/json" - "reflect" - "testing" - - "tailscale.com/health" - "tailscale.com/ipn/store/mem" - "tailscale.com/tailcfg" - "tailscale.com/util/eventbus/eventbustest" - "tailscale.com/util/must" -) - -func TestSSHKeyGen(t *testing.T) { - dir := t.TempDir() - lb := &LocalBackend{varRoot: dir} - keys, err := lb.getTailscaleSSH_HostKeys(nil) - if err != nil { - t.Fatal(err) - } - got := map[string]bool{} - for _, k := range keys { - got[k.PublicKey().Type()] = true - } - want := map[string]bool{ - "ssh-rsa": true, - "ecdsa-sha2-nistp256": true, - "ssh-ed25519": true, - } - if !reflect.DeepEqual(got, want) { - t.Fatalf("keys = %v; want %v", got, want) - } - - keys2, err := lb.getTailscaleSSH_HostKeys(nil) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(keys, keys2) { - t.Errorf("got different keys on second call") - } -} - -type fakeSSHServer struct { - SSHServer -} - -func TestGetSSHUsernames(t *testing.T) { - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) - b := &LocalBackend{pm: pm, store: pm.Store()} - b.sshServer = fakeSSHServer{} - res, err := b.getSSHUsernames(new(tailcfg.C2NSSHUsernamesRequest)) - if err != nil { - t.Fatal(err) - } - t.Logf("Got: %s", must.Get(json.Marshal(res))) -} diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go index 39796ec325367..104c29a3f3e2b 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -136,6 +136,7 @@ type mockControl struct { calls []string authBlocked bool shutdown chan struct{} + loginFlags controlclient.LoginFlags hi *tailcfg.Hostinfo } @@ -273,6 +274,7 @@ func (cc *mockControl) Login(flags controlclient.LoginFlags) { cc.mu.Lock() defer cc.mu.Unlock() cc.authBlocked = interact || newKeys + cc.loginFlags |= flags } func (cc *mockControl) Logout(ctx context.Context) error { @@ -336,6 +338,8 @@ func (cc *mockControl) ClientID() int64 { return cc.controlClientID } +func (cc *mockControl) SetIPForwardingBroken(bool) {} + func (b *LocalBackend) nonInteractiveLoginForStateTest() { b.mu.Lock() if b.cc == nil { @@ -369,14 +373,6 @@ func (b *LocalBackend) nonInteractiveLoginForStateTest() { // predictable, but maybe a bit less thorough. This is more of an overall // state machine test than a test of the wgengine+magicsock integration. func TestStateMachine(t *testing.T) { - runTestStateMachine(t, false) -} - -func TestStateMachineSeamless(t *testing.T) { - runTestStateMachine(t, true) -} - -func runTestStateMachine(t *testing.T, seamless bool) { envknob.Setenv("TAILSCALE_USE_WIP_CODE", "1") defer envknob.Setenv("TAILSCALE_USE_WIP_CODE", "") c := qt.New(t) @@ -586,12 +582,6 @@ func runTestStateMachine(t *testing.T, seamless bool) { cc.persist.UserProfile.LoginName = "user1" cc.persist.NodeID = "node1" - // even if seamless is being enabled by default rather than by policy, this is - // the point where it will first get enabled. - if seamless { - sys.ControlKnobs().SeamlessKeyRenewal.Store(true) - } - cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{}}) { nn := notifies.drain(3) @@ -606,7 +596,7 @@ func runTestStateMachine(t *testing.T, seamless bool) { cc.assertCalls() c.Assert(nn[0].LoginFinished, qt.IsNotNil) c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(nn[1].Prefs.Persist().UserProfile().LoginName, qt.Equals, "user1") + c.Assert(nn[1].Prefs.Persist().UserProfile().LoginName(), qt.Equals, "user1") // nn[2] is a state notification after login // Verify login finished but need machine auth using backend state c.Assert(isFullyAuthenticated(b), qt.IsTrue) @@ -694,6 +684,7 @@ func runTestStateMachine(t *testing.T, seamless bool) { notifies.expect(5) b.Logout(context.Background(), ipnauth.Self) { + b.awaitNoGoroutinesInTest() nn := notifies.drain(5) previousCC.assertCalls("pause", "Logout", "unpause", "Shutdown") // nn[0] is state notification (Stopped) @@ -818,7 +809,7 @@ func runTestStateMachine(t *testing.T, seamless bool) { c.Assert(nn[1].Prefs, qt.IsNotNil) c.Assert(nn[1].Prefs.Persist(), qt.IsNotNil) // Prefs after finishing the login, so LoginName updated. - c.Assert(nn[1].Prefs.Persist().UserProfile().LoginName, qt.Equals, "user2") + c.Assert(nn[1].Prefs.Persist().UserProfile().LoginName(), qt.Equals, "user2") c.Assert(nn[1].Prefs.LoggedOut(), qt.IsFalse) // If a user initiates an interactive login, they also expect WantRunning to become true. c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue) @@ -871,7 +862,9 @@ func runTestStateMachine(t *testing.T, seamless bool) { // additional netmap updates. Since our LocalBackend instance already // has a netmap, we will reset it to nil to simulate the first netmap // retrieval. + b.mu.Lock() b.setNetMapLocked(nil) + b.mu.Unlock() cc.assertCalls("unpause") // // TODO: really the various GUIs and prefs should be refactored to @@ -964,7 +957,7 @@ func runTestStateMachine(t *testing.T, seamless bool) { c.Assert(nn[0].LoginFinished, qt.IsNotNil) c.Assert(nn[1].Prefs, qt.IsNotNil) // Prefs after finishing the login, so LoginName updated. - c.Assert(nn[1].Prefs.Persist().UserProfile().LoginName, qt.Equals, "user3") + c.Assert(nn[1].Prefs.Persist().UserProfile().LoginName(), qt.Equals, "user3") c.Assert(nn[1].Prefs.LoggedOut(), qt.IsFalse) c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue) // nn[2] is state notification (Starting) - verify using backend state @@ -1050,6 +1043,7 @@ func runTestStateMachine(t *testing.T, seamless bool) { } notifies.expect(1) // Fake a DERP connection. + b.awaitNoGoroutinesInTest() b.setWgengineStatus(&wgengine.Status{DERPs: 1, AsOf: time.Now()}, nil) { nn := notifies.drain(1) @@ -1472,26 +1466,7 @@ func TestEngineReconfigOnStateChange(t *testing.T) { lb.StartLoginInteractive(context.Background()) cc().sendAuthURL(node1) }, - // Without seamless renewal, even starting a reauth tears down everything: - wantState: ipn.Starting, - wantCfg: &wgcfg.Config{}, - wantRouterCfg: &router.Config{}, - wantDNSCfg: &dns.Config{}, - }, - { - name: "Start/Connect/Login/InitReauth/Login", - steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { - mustDo(t)(lb.Start(ipn.Options{})) - mustDo2(t)(lb.EditPrefs(connect)) - cc().authenticated(node1) - - // Start the re-auth process: - lb.StartLoginInteractive(context.Background()) - cc().sendAuthURL(node1) - - // Complete the re-auth process: - cc().authenticated(node1) - }, + // Starting a reauth should leave everything up: wantState: ipn.Starting, wantCfg: &wgcfg.Config{ Peers: []wgcfg.Peer{}, @@ -1510,39 +1485,8 @@ func TestEngineReconfigOnStateChange(t *testing.T) { }, }, { - name: "Seamless/Start/Connect/Login/InitReauth", - steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { - lb.ControlKnobs().SeamlessKeyRenewal.Store(true) - mustDo(t)(lb.Start(ipn.Options{})) - mustDo2(t)(lb.EditPrefs(connect)) - cc().authenticated(node1) - - // Start the re-auth process: - lb.StartLoginInteractive(context.Background()) - cc().sendAuthURL(node1) - }, - // With seamless renewal, starting a reauth should leave everything up: - wantState: ipn.Starting, - wantCfg: &wgcfg.Config{ - Peers: []wgcfg.Peer{}, - Addresses: node1.SelfNode.Addresses().AsSlice(), - }, - wantRouterCfg: &router.Config{ - SNATSubnetRoutes: true, - NetfilterMode: preftype.NetfilterOn, - LocalAddrs: node1.SelfNode.Addresses().AsSlice(), - Routes: routesWithQuad100(), - }, - wantDNSCfg: &dns.Config{ - AcceptDNS: true, - Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, - Hosts: hostsFor(node1), - }, - }, - { - name: "Seamless/Start/Connect/Login/InitReauth/Login", + name: "Start/Connect/Login/InitReauth/Login", steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { - lb.ControlKnobs().SeamlessKeyRenewal.Store(true) mustDo(t)(lb.Start(ipn.Options{})) mustDo2(t)(lb.EditPrefs(connect)) cc().authenticated(node1) @@ -1572,9 +1516,8 @@ func TestEngineReconfigOnStateChange(t *testing.T) { }, }, { - name: "Seamless/Start/Connect/Login/Expire", + name: "Start/Connect/Login/Expire", steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { - lb.ControlKnobs().SeamlessKeyRenewal.Store(true) mustDo(t)(lb.Start(ipn.Options{})) mustDo2(t)(lb.EditPrefs(connect)) cc().authenticated(node1) @@ -1584,7 +1527,7 @@ func TestEngineReconfigOnStateChange(t *testing.T) { }).View(), }}) }, - // Even with seamless, if the key we are using expires, we want to disconnect: + // If the key we are using expires, we want to disconnect: wantState: ipn.NeedsLogin, wantCfg: &wgcfg.Config{}, wantRouterCfg: &router.Config{}, @@ -1633,14 +1576,6 @@ func TestEngineReconfigOnStateChange(t *testing.T) { // TestSendPreservesAuthURL tests that wgengine updates arriving in the middle of // processing an auth URL doesn't result in the auth URL being cleared. func TestSendPreservesAuthURL(t *testing.T) { - runTestSendPreservesAuthURL(t, false) -} - -func TestSendPreservesAuthURLSeamless(t *testing.T) { - runTestSendPreservesAuthURL(t, true) -} - -func runTestSendPreservesAuthURL(t *testing.T, seamless bool) { var cc *mockControl b := newLocalBackendWithTestControl(t, true, func(tb testing.TB, opts controlclient.Options) controlclient.Client { cc = newClient(t, opts) @@ -1659,10 +1594,6 @@ func runTestSendPreservesAuthURL(t *testing.T, seamless bool) { cc.persist.UserProfile.LoginName = "user1" cc.persist.NodeID = "node1" - if seamless { - b.sys.ControlKnobs().SeamlessKeyRenewal.Store(true) - } - cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{ SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), }}) @@ -2007,6 +1938,8 @@ func (e *mockEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size int, cb func (e *mockEngine) InstallCaptureHook(packet.CaptureCallback) {} +func (e *mockEngine) SetPeerByIPPacketFunc(func(netip.Addr) (_ key.NodePublic, ok bool)) {} + func (e *mockEngine) Close() { e.mu.Lock() defer e.mu.Unlock() diff --git a/ipn/ipnlocal/web_client.go b/ipn/ipnlocal/web_client.go index 37dba93d0a49b..6ab68858ee701 100644 --- a/ipn/ipnlocal/web_client.go +++ b/ipn/ipnlocal/web_client.go @@ -173,7 +173,7 @@ func (b *LocalBackend) waitWebClientAuthURL(ctx context.Context, id string, src // one to be completed, based on the presence or absence of the // provided id value. func (b *LocalBackend) doWebClientNoiseRequest(ctx context.Context, id string, src tailcfg.NodeID) (*tailcfg.WebClientAuthResponse, error) { - nm := b.NetMap() + nm := b.NetMapNoPeers() if nm == nil || !nm.SelfNode.Valid() { return nil, errors.New("[unexpected] no self node") } diff --git a/ipn/ipnserver/actor.go b/ipn/ipnserver/actor.go index c9a4c6e891f86..985a6ef7a712d 100644 --- a/ipn/ipnserver/actor.go +++ b/ipn/ipnserver/actor.go @@ -145,7 +145,7 @@ func (a *actor) Username() (string, error) { } defer tok.Close() return tok.Username() - case "darwin", "linux", "illumos", "solaris", "openbsd": + case "darwin", "linux", "illumos", "solaris", "openbsd", "freebsd": creds := a.ci.Creds() if creds == nil { return "", errors.New("peer credentials not implemented on this OS") @@ -237,7 +237,7 @@ func connIsLocalAdmin(logf logger.Logf, ci *ipnauth.ConnIdentity, operatorUID st // This is a standalone tailscaled setup, use the same logic as on // Linux. fallthrough - case "linux": + case "linux", "solaris", "illumos": if !buildfeatures.HasUnixSocketIdentity { // Everybody is an admin if support for unix socket identities // is omitted for the build. diff --git a/ipn/ipnserver/server.go b/ipn/ipnserver/server.go index 1f8abf0e20128..19efaf9895b94 100644 --- a/ipn/ipnserver/server.go +++ b/ipn/ipnserver/server.go @@ -429,7 +429,7 @@ func (s *Server) addActiveHTTPRequest(req *http.Request, actor ipnauth.Actor) (o if len(s.activeReqs) == 1 { if envknob.GOOS() == "windows" && !actor.IsLocalSystem() { // Tell the LocalBackend about the identity we're now running as, - // unless its the SYSTEM user. That user is not a real account and + // unless it's the SYSTEM user. That user is not a real account and // doesn't have a home directory. lb.SetCurrentUser(actor) } diff --git a/ipn/ipnserver/server_test.go b/ipn/ipnserver/server_test.go index 9aa9c4c015f23..45a8d622d3e73 100644 --- a/ipn/ipnserver/server_test.go +++ b/ipn/ipnserver/server_test.go @@ -16,7 +16,6 @@ import ( "tailscale.com/ipn" "tailscale.com/ipn/lapitest" "tailscale.com/tsd" - "tailscale.com/types/ptr" "tailscale.com/util/syspolicy/pkey" "tailscale.com/util/syspolicy/policytest" ) @@ -49,7 +48,7 @@ func TestUserConnectDisconnectNonWindows(t *testing.T) { // And if we send a notification, both users should receive it. wantErrMessage := "test error" - testNotify := ipn.Notify{ErrMessage: ptr.To(wantErrMessage)} + testNotify := ipn.Notify{ErrMessage: new(wantErrMessage)} server.Backend().DebugNotify(testNotify) if n, err := watcherA.Next(); err != nil { @@ -274,12 +273,12 @@ func TestShutdownViaLocalAPI(t *testing.T) { }, { name: "AllowTailscaledRestart/False", - allowTailscaledRestart: ptr.To(false), + allowTailscaledRestart: new(false), wantErr: errAccessDeniedByPolicy, }, { name: "AllowTailscaledRestart/True", - allowTailscaledRestart: ptr.To(true), + allowTailscaledRestart: new(true), wantErr: nil, // shutdown should be allowed }, } diff --git a/ipn/ipnstate/ipnstate.go b/ipn/ipnstate/ipnstate.go index 4d219d131d528..17e6ac870bead 100644 --- a/ipn/ipnstate/ipnstate.go +++ b/ipn/ipnstate/ipnstate.go @@ -20,7 +20,6 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tka" "tailscale.com/types/key" - "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/dnsname" "tailscale.com/version" @@ -535,7 +534,7 @@ func (sb *StatusBuilder) AddPeer(peer key.NodePublic, st *PeerStatus) { e.Expired = true } if t := st.KeyExpiry; t != nil { - e.KeyExpiry = ptr.To(*t) + e.KeyExpiry = new(*t) } if v := st.CapMap; v != nil { e.CapMap = v diff --git a/ipn/lapitest/server.go b/ipn/lapitest/server.go index 8fd3c8cdd361f..2686682af15c9 100644 --- a/ipn/lapitest/server.go +++ b/ipn/lapitest/server.go @@ -22,7 +22,6 @@ import ( "tailscale.com/ipn/ipnserver" "tailscale.com/types/logger" "tailscale.com/types/logid" - "tailscale.com/types/ptr" "tailscale.com/util/mak" "tailscale.com/util/rands" ) @@ -153,7 +152,7 @@ func (s *Server) MakeTestActor(name string, clientID string) *ipnauth.TestActor } // Create a shallow copy of the base actor and assign it the new client ID. - actor := ptr.To(*baseActor) + actor := new(*baseActor) actor.CID = ipnauth.ClientIDFrom(clientID) return actor } diff --git a/ipn/localapi/debug.go b/ipn/localapi/debug.go index d1348abaafef5..6f222bef08ac3 100644 --- a/ipn/localapi/debug.go +++ b/ipn/localapi/debug.go @@ -9,6 +9,7 @@ import ( "cmp" "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -142,14 +143,11 @@ func (h *Handler) serveDebugDialTypes(w http.ResponseWriter, r *http.Request) { var wg sync.WaitGroup for _, dialer := range dialers { - dialer := dialer // loop capture - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { conn, err := dialer.dial(ctx, network, addr) results <- result{dialer.name, conn, err} - }() + }) } wg.Wait() @@ -235,8 +233,39 @@ func (h *Handler) serveDebug(w http.ResponseWriter, r *http.Request) { if err == nil { return } + case "peer-disco-keys": + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(h.b.DebugPeerDiscoKeys()) + if err == nil { + return + } case "rotate-disco-key": err = h.b.DebugRotateDiscoKey() + case "statedir": + root := h.b.TailscaleVarRoot() + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(root) + if err == nil { + return + } + case "clear-netmap-cache": + h.b.ClearNetmapCache(r.Context()) + case "current-netmap": + // Return the current netmap (with peers populated) as JSON. This + // is a debug-only path: the netmap.NetworkMap shape is an + // internal type and may change without notice. Production + // callers should fetch the narrower bits they need via their + // own LocalAPI methods instead. + nm := h.b.NetMapWithPeers() + if nm == nil { + err = errors.New("no netmap") + break + } + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(nm) + if err == nil { + return + } case "": err = fmt.Errorf("missing parameter 'action'") default: @@ -272,7 +301,7 @@ func (h *Handler) serveDebugPacketFilterRules(w http.ResponseWriter, r *http.Req http.Error(w, "debug access denied", http.StatusForbidden) return } - nm := h.b.NetMap() + nm := h.b.NetMapNoPeers() if nm == nil { http.Error(w, "no netmap", http.StatusNotFound) return @@ -289,7 +318,7 @@ func (h *Handler) serveDebugPacketFilterMatches(w http.ResponseWriter, r *http.R http.Error(w, "debug access denied", http.StatusForbidden) return } - nm := h.b.NetMap() + nm := h.b.NetMapNoPeers() if nm == nil { http.Error(w, "no netmap", http.StatusNotFound) return diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index ed25e875da409..6375f440d1e85 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -43,7 +43,6 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" - "tailscale.com/types/ptr" "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" "tailscale.com/util/httpm" @@ -73,16 +72,20 @@ var handler = map[string]LocalAPIHandler{ // The other /localapi/v0/NAME handlers are exact matches and contain only NAME // without a trailing slash: + "cert-domains": (*Handler).serveCertDomains, "check-prefs": (*Handler).serveCheckPrefs, "check-so-mark-in-use": (*Handler).serveCheckSOMarkInUse, "derpmap": (*Handler).serveDERPMap, + "dns-config": (*Handler).serveDNSConfig, "goroutines": (*Handler).serveGoroutines, "login-interactive": (*Handler).serveLoginInteractive, "logout": (*Handler).serveLogout, + "peer-by-id": (*Handler).servePeerByID, "ping": (*Handler).servePing, "prefs": (*Handler).servePrefs, "reload-config": (*Handler).reloadConfig, "reset-auth": (*Handler).serveResetAuth, + "services": (*Handler).serveServices, "set-expiry-sooner": (*Handler).serveSetExpirySooner, "shutdown": (*Handler).serveShutdown, "start": (*Handler).serveStart, @@ -116,7 +119,7 @@ func init() { Register("bugreport", (*Handler).serveBugReport) Register("pprof", (*Handler).servePprof) } - if buildfeatures.HasDebug || buildfeatures.HasServe { + if buildfeatures.HasIPNBus { Register("watch-ipn-bus", (*Handler).serveWatchIPNBus) } if buildfeatures.HasDNS { @@ -347,7 +350,7 @@ func (h *Handler) serveIDToken(w http.ResponseWriter, r *http.Request) { http.Error(w, "id-token access denied", http.StatusForbidden) return } - nm := h.b.NetMap() + nm := h.b.NetMapNoPeers() if nm == nil { http.Error(w, "no netmap", http.StatusServiceUnavailable) return @@ -417,7 +420,7 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) { } // Information about the current node from the netmap - if nm := h.b.NetMap(); nm != nil { + if nm := h.b.NetMapNoPeers(); nm != nil { if self := nm.SelfNode; self.Valid() { h.logf("user bugreport node info: nodeid=%q stableid=%q expiry=%q", self.ID(), self.StableID(), self.KeyExpiry().Format(time.RFC3339)) } @@ -845,8 +848,8 @@ func InUseOtherUserIPNStream(w http.ResponseWriter, r *http.Request, err error) } js, err := json.Marshal(&ipn.Notify{ Version: version.Long(), - State: ptr.To(ipn.InUseOtherUser), - ErrMessage: ptr.To(err.Error()), + State: new(ipn.InUseOtherUser), + ErrMessage: new(err.Error()), }) if err != nil { return false @@ -1073,6 +1076,80 @@ func (h *Handler) serveDERPMap(w http.ResponseWriter, r *http.Request) { e.Encode(h.b.DERPMap()) } +// serveCertDomains returns the list of DNS.CertDomains from the current +// netmap, or an empty list if no netmap has been received yet. +// The returned list is sorted in ascending order. +func (h *Handler) serveCertDomains(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "cert-domains access denied", http.StatusForbidden) + return + } + var domains []string + if nm := h.b.NetMapNoPeers(); nm != nil { + domains = slices.Clone(nm.DNS.CertDomains) + slices.Sort(domains) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(domains) +} + +// serveDNSConfig returns the [tailcfg.DNSConfig] from the current netmap. +// It returns 503 if no netmap has been received yet. +func (h *Handler) serveDNSConfig(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "dns-config access denied", http.StatusForbidden) + return + } + nm := h.b.NetMapNoPeers() + if nm == nil { + http.Error(w, "no netmap", http.StatusServiceUnavailable) + return + } + w.Header().Set("Content-Type", "application/json") + e := json.NewEncoder(w) + e.SetIndent("", "\t") + e.Encode(nm.DNS) +} + +// peerByIDBackend is the subset of [ipnlocal.LocalBackend] used by +// [Handler.servePeerByID]. It exists so the handler can be tested with a +// trivial mock without spinning up a full LocalBackend. +type peerByIDBackend interface { + PeerByID(tailcfg.NodeID) (tailcfg.NodeView, bool) +} + +// servePeerByID returns the current full [tailcfg.Node] for the peer with +// the NodeID given in the "id" query parameter, in O(1) time. It returns +// 404 if no such peer is in the current netmap. +// +// It is intended for clients that need the latest state of a single peer +// without fetching the entire netmap. +func (h *Handler) servePeerByID(w http.ResponseWriter, r *http.Request) { + h.servePeerByIDWithBackend(w, r, h.b) +} + +func (h *Handler) servePeerByIDWithBackend(w http.ResponseWriter, r *http.Request, b peerByIDBackend) { + if !h.PermitRead { + http.Error(w, "peer-by-id access denied", http.StatusForbidden) + return + } + idStr := r.FormValue("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil || id <= 0 { + http.Error(w, "invalid 'id' parameter", http.StatusBadRequest) + return + } + nv, ok := b.PeerByID(tailcfg.NodeID(id)) + if !ok { + http.Error(w, "no peer with that NodeID", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + e := json.NewEncoder(w) + e.SetIndent("", "\t") + e.Encode(nv.AsStruct()) +} + // serveSetExpirySooner sets the expiry date on the current machine, specified // by an `expiry` unix timestamp as POST or query param. func (h *Handler) serveSetExpirySooner(w http.ResponseWriter, r *http.Request) { @@ -1169,16 +1246,34 @@ func (h *Handler) serveDial(w http.ResponseWriter, r *http.Request) { http.Error(w, "missing Dial-Host or Dial-Port header", http.StatusBadRequest) return } + network := cmp.Or(r.Header.Get("Dial-Network"), "tcp") + + addr := net.JoinHostPort(hostStr, portStr) + + // Check whether the resolved address is a Tailscale route. + // If not, tell the client to dial it directly so the connection + // comes from the calling user's UID rather than our root-owned daemon. + ipp, viaTailscale, err := h.b.Dialer().UserDialPlan(r.Context(), network, addr) + if err != nil { + http.Error(w, "resolve failure: "+err.Error(), http.StatusBadGateway) + return + } + if !viaTailscale { + w.Header().Set("Dial-Self", "true") + w.Header().Set("Dial-Addr", ipp.String()) + w.WriteHeader(http.StatusOK) + return + } + hijacker, ok := w.(http.Hijacker) if !ok { http.Error(w, "make request over HTTP/1", http.StatusBadRequest) return } - network := cmp.Or(r.Header.Get("Dial-Network"), "tcp") - - addr := net.JoinHostPort(hostStr, portStr) - outConn, err := h.b.Dialer().UserDial(r.Context(), network, addr) + // Dial via Tailscale using the resolved IP:port to avoid a TOCTOU + // race with DNS re-resolution. + outConn, err := h.b.Dialer().UserDial(r.Context(), network, ipp.String()) if err != nil { http.Error(w, "dial failure: "+err.Error(), http.StatusBadGateway) return @@ -1458,7 +1553,7 @@ func (h *Handler) serveQueryFeature(w http.ResponseWriter, r *http.Request) { http.Error(w, "missing feature", http.StatusInternalServerError) return } - nm := h.b.NetMap() + nm := h.b.NetMapNoPeers() if nm == nil { http.Error(w, "no netmap", http.StatusServiceUnavailable) return @@ -1708,6 +1803,20 @@ func (h *Handler) serveShutdown(w http.ResponseWriter, r *http.Request) { eventbus.Publish[Shutdown](ec).Publish(Shutdown{}) } +func (h *Handler) serveServices(w http.ResponseWriter, r *http.Request) { + if r.Method != httpm.GET { + http.Error(w, "only GET allowed", http.StatusMethodNotAllowed) + return + } + nm := h.b.NetMapNoPeers() + if nm == nil { + http.Error(w, "no netmap", http.StatusServiceUnavailable) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(nm.Services()) +} + func (h *Handler) serveGetAppcRouteInfo(w http.ResponseWriter, r *http.Request) { if !buildfeatures.HasAppConnectors { http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index 47e33457188ab..352f71e00466d 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -202,6 +202,72 @@ func TestWhoIsArgTypes(t *testing.T) { } } +type fakePeerByIDBackend map[tailcfg.NodeID]*tailcfg.Node + +func (f fakePeerByIDBackend) PeerByID(id tailcfg.NodeID) (tailcfg.NodeView, bool) { + n, ok := f[id] + if !ok { + return tailcfg.NodeView{}, false + } + return n.View(), true +} + +func TestServePeerByID(t *testing.T) { + h := handlerForTest(t, &Handler{PermitRead: true}) + b := fakePeerByIDBackend{ + 42: { + ID: 42, + Name: "alpha", + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.42/32"), + }, + }, + } + + tests := []struct { + name string + query string + wantCode int + wantNodeID tailcfg.NodeID + }{ + {"hit", "id=42", 200, 42}, + {"miss", "id=99", 404, 0}, + {"bad_id", "id=garbage", 400, 0}, + {"missing_id", "", 400, 0}, + {"zero_id", "id=0", 400, 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/v0/peer-by-id?"+tt.query, nil) + h.servePeerByIDWithBackend(rec, req, b) + if rec.Code != tt.wantCode { + t.Fatalf("status = %d, want %d; body=%q", rec.Code, tt.wantCode, rec.Body.String()) + } + if tt.wantCode != 200 { + return + } + var got tailcfg.Node + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("unmarshal body %q: %v", rec.Body.Bytes(), err) + } + if got.ID != tt.wantNodeID { + t.Errorf("Node.ID = %d, want %d", got.ID, tt.wantNodeID) + } + }) + } + + t.Run("forbidden", func(t *testing.T) { + hh := handlerForTest(t, &Handler{PermitRead: false}) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/v0/peer-by-id?id=42", nil) + hh.servePeerByIDWithBackend(rec, req, b) + if rec.Code != http.StatusForbidden { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden) + } + }) +} + func TestShouldDenyServeConfigForGOOSAndUserContext(t *testing.T) { newHandler := func(connIsLocalAdmin bool) *Handler { return handlerForTest(t, &Handler{ @@ -500,3 +566,69 @@ func TestServeWithUnhealthyState(t *testing.T) { }) } } + +func TestServeDialSelf(t *testing.T) { + h := handlerForTest(t, &Handler{ + PermitRead: true, + PermitWrite: true, + b: newTestLocalBackend(t), + }) + + tests := []struct { + name string + host string + port string + wantSelf bool + wantAddr string + wantStatus int + }{ + { + name: "loopback_v4", + host: "127.0.0.1", + port: "8080", + wantSelf: true, + wantAddr: "127.0.0.1:8080", + wantStatus: http.StatusOK, + }, + { + name: "loopback_v6", + host: "::1", + port: "8080", + wantSelf: true, + wantAddr: "[::1]:8080", + wantStatus: http.StatusOK, + }, + { + name: "localhost", + host: "localhost", + port: "3000", + wantSelf: true, + wantStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "http://local-tailscaled.sock/localapi/v0/dial", nil) + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "ts-dial") + req.Header.Set("Dial-Host", tt.host) + req.Header.Set("Dial-Port", tt.port) + resp := httptest.NewRecorder() + h.serveDial(resp, req) + + if resp.Code != tt.wantStatus { + t.Fatalf("status = %d, want %d; body: %s", resp.Code, tt.wantStatus, resp.Body.String()) + } + gotSelf := resp.Header().Get("Dial-Self") == "true" + if gotSelf != tt.wantSelf { + t.Errorf("Dial-Self = %v, want %v", gotSelf, tt.wantSelf) + } + if tt.wantAddr != "" { + if got := resp.Header().Get("Dial-Addr"); got != tt.wantAddr { + t.Errorf("Dial-Addr = %q, want %q", got, tt.wantAddr) + } + } + }) + } +} diff --git a/ipn/localapi/tailnetlock.go b/ipn/localapi/tailnetlock.go index 445f705056cf7..e2a2850cf331d 100644 --- a/ipn/localapi/tailnetlock.go +++ b/ipn/localapi/tailnetlock.go @@ -122,7 +122,7 @@ func (h *Handler) serveTKAInit(w http.ResponseWriter, r *http.Request) { func (h *Handler) serveTKAModify(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { - http.Error(w, "network-lock modify access denied", http.StatusForbidden) + http.Error(w, "tailnet-lock modify access denied", http.StatusForbidden) return } if r.Method != httpm.POST { @@ -141,7 +141,7 @@ func (h *Handler) serveTKAModify(w http.ResponseWriter, r *http.Request) { } if err := h.b.NetworkLockModify(req.AddKeys, req.RemoveKeys); err != nil { - http.Error(w, "network-lock modify failed: "+err.Error(), http.StatusInternalServerError) + http.Error(w, "tailnet-lock modify failed: "+err.Error(), http.StatusInternalServerError) return } w.WriteHeader(204) @@ -149,7 +149,7 @@ func (h *Handler) serveTKAModify(w http.ResponseWriter, r *http.Request) { func (h *Handler) serveTKAWrapPreauthKey(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { - http.Error(w, "network-lock modify access denied", http.StatusForbidden) + http.Error(w, "tailnet-lock modify access denied", http.StatusForbidden) return } if r.Method != httpm.POST { @@ -212,7 +212,7 @@ func (h *Handler) serveTKAVerifySigningDeeplink(w http.ResponseWriter, r *http.R func (h *Handler) serveTKADisable(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { - http.Error(w, "network-lock modify access denied", http.StatusForbidden) + http.Error(w, "tailnet-lock modify access denied", http.StatusForbidden) return } if r.Method != httpm.POST { @@ -228,7 +228,7 @@ func (h *Handler) serveTKADisable(w http.ResponseWriter, r *http.Request) { } if err := h.b.NetworkLockDisable(secret); err != nil { - http.Error(w, "network-lock disable failed: "+err.Error(), http.StatusBadRequest) + http.Error(w, "tailnet-lock disable failed: "+err.Error(), http.StatusBadRequest) return } w.WriteHeader(http.StatusOK) @@ -236,7 +236,7 @@ func (h *Handler) serveTKADisable(w http.ResponseWriter, r *http.Request) { func (h *Handler) serveTKALocalDisable(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { - http.Error(w, "network-lock modify access denied", http.StatusForbidden) + http.Error(w, "tailnet-lock modify access denied", http.StatusForbidden) return } if r.Method != httpm.POST { @@ -252,7 +252,7 @@ func (h *Handler) serveTKALocalDisable(w http.ResponseWriter, r *http.Request) { } if err := h.b.NetworkLockForceLocalDisable(); err != nil { - http.Error(w, "network-lock local disable failed: "+err.Error(), http.StatusBadRequest) + http.Error(w, "tailnet-lock local disable failed: "+err.Error(), http.StatusBadRequest) return } w.WriteHeader(http.StatusOK) diff --git a/ipn/prefs.go b/ipn/prefs.go index 72e0cf8b78424..9125df2c1a76b 100644 --- a/ipn/prefs.go +++ b/ipn/prefs.go @@ -439,12 +439,11 @@ func applyPrefsEdits(src, dst reflect.Value, mask map[string]reflect.Value) { func maskFields(v reflect.Value) map[string]reflect.Value { mask := make(map[string]reflect.Value) - for i := range v.NumField() { - f := v.Type().Field(i).Name - if !strings.HasSuffix(f, "Set") { + for sf, fv := range v.Fields() { + if !strings.HasSuffix(sf.Name, "Set") { continue } - mask[strings.TrimSuffix(f, "Set")] = v.Field(i) + mask[strings.TrimSuffix(sf.Name, "Set")] = fv } return mask } @@ -845,22 +844,15 @@ func (p *Prefs) SetAdvertiseExitNode(runExit bool) { // Tailscale IP. func peerWithTailscaleIP(st *ipnstate.Status, ip netip.Addr) (ps *ipnstate.PeerStatus, ok bool) { for _, ps := range st.Peer { - for _, ip2 := range ps.TailscaleIPs { - if ip == ip2 { - return ps, true - } + if slices.Contains(ps.TailscaleIPs, ip) { + return ps, true } } return nil, false } func isRemoteIP(st *ipnstate.Status, ip netip.Addr) bool { - for _, selfIP := range st.TailscaleIPs { - if ip == selfIP { - return false - } - } - return true + return !slices.Contains(st.TailscaleIPs, ip) } // ClearExitNode sets the ExitNodeID and ExitNodeIP to their zero values. @@ -904,8 +896,17 @@ func exitNodeIPOfArg(s string, st *ipnstate.Status) (ip netip.Addr, err error) { } match := 0 for _, ps := range st.Peer { - baseName := dnsname.TrimSuffix(ps.DNSName, st.MagicDNSSuffix) - if !strings.EqualFold(s, baseName) && !strings.EqualFold(s, ps.DNSName) { + // Compare to the peer name in three forms: + // + // - base name ("example") + // - FQDN ("example.tail1234.ts.net.") + // - FQDN sans dot ("example.tail1234.ts.net", as returned by `tailscale exit-node list` + // and the admin console) + // + fqdn := ps.DNSName + baseName := dnsname.TrimSuffix(fqdn, st.MagicDNSSuffix) + fqdnSansDot := dnsname.TrimSuffix(fqdn, ".") + if !strings.EqualFold(s, baseName) && !strings.EqualFold(s, fqdn) && !strings.EqualFold(s, fqdnSansDot) { continue } match++ @@ -919,7 +920,7 @@ func exitNodeIPOfArg(s string, st *ipnstate.Status) (ip netip.Addr, err error) { } switch match { case 0: - return ip, fmt.Errorf("invalid value %q for --exit-node; must be IP or unique node name", s) + return ip, fmt.Errorf("invalid value %q for --exit-node; must be IP or hostname", s) case 1: if !isRemoteIP(st, ip) { return ip, ExitNodeLocalIPError{s} diff --git a/ipn/prefs_test.go b/ipn/prefs_test.go index 347a91e50739c..31dd2c55a3182 100644 --- a/ipn/prefs_test.go +++ b/ipn/prefs_test.go @@ -27,8 +27,8 @@ import ( ) func fieldsOf(t reflect.Type) (fields []string) { - for i := range t.NumField() { - fields = append(fields, t.Field(i).Name) + for field := range t.Fields() { + fields = append(fields, field.Name) } return } @@ -458,7 +458,7 @@ func TestPrefsFromBytesPreservesOldValues(t *testing.T) { want: Prefs{ControlURL: "https://foo", RouteAll: true}, }, { - name: "opt.Bool", // test that we don't normalize it early + name: "opt-Bool", // test that we don't normalize it early old: Prefs{Sync: "unset"}, json: []byte(`{}`), want: Prefs{Sync: "unset"}, @@ -1009,7 +1009,7 @@ func TestExitNodeIPOfArg(t *testing.T) { name: "no_match", arg: "unknown", st: &ipnstate.Status{MagicDNSSuffix: ".foo"}, - wantErr: `invalid value "unknown" for --exit-node; must be IP or unique node name`, + wantErr: `invalid value "unknown" for --exit-node; must be IP or hostname`, }, { name: "name", @@ -1041,6 +1041,21 @@ func TestExitNodeIPOfArg(t *testing.T) { }, want: mustIP("1.0.0.2"), }, + { + name: "name_fqdn_sans_dot", + arg: "skippy.foo", + st: &ipnstate.Status{ + MagicDNSSuffix: ".foo", + Peer: map[key.NodePublic]*ipnstate.PeerStatus{ + key.NewNode().Public(): { + DNSName: "skippy.foo.", + TailscaleIPs: []netip.Addr{mustIP("1.0.0.2")}, + ExitNodeOption: true, + }, + }, + }, + want: mustIP("1.0.0.2"), + }, { name: "name_not_exit", arg: "skippy", @@ -1067,7 +1082,7 @@ func TestExitNodeIPOfArg(t *testing.T) { }, }, }, - wantErr: `invalid value "skippy.bar." for --exit-node; must be IP or unique node name`, + wantErr: `invalid value "skippy.bar." for --exit-node; must be IP or hostname`, }, { name: "ambiguous", @@ -1221,13 +1236,13 @@ func TestParseAutoExitNodeString(t *testing.T) { wantExpr ExitNodeExpression }{ { - name: "empty expr", + name: "empty-expr", exitNodeID: "", wantOk: false, wantExpr: "", }, { - name: "no auto prefix", + name: "no-auto-prefix", exitNodeID: "foo", wantOk: false, wantExpr: "", @@ -1245,13 +1260,13 @@ func TestParseAutoExitNodeString(t *testing.T) { wantExpr: "foo", }, { - name: "auto prefix but empty suffix", + name: "auto-prefix-empty-suffix", exitNodeID: "auto:", wantOk: false, wantExpr: "", }, { - name: "auto prefix no colon", + name: "auto-prefix-no-colon", exitNodeID: "auto", wantOk: false, wantExpr: "", diff --git a/ipn/serve.go b/ipn/serve.go index 911b408b65026..21d15ab818fc9 100644 --- a/ipn/serve.go +++ b/ipn/serve.go @@ -673,7 +673,7 @@ func CheckFunnelPort(wantedPort uint16, node *ipnstate.PeerStatus) error { return deny("") } wantedPortString := strconv.Itoa(int(wantedPort)) - for _, ps := range strings.Split(portsStr, ",") { + for ps := range strings.SplitSeq(portsStr, ",") { if ps == "" { continue } diff --git a/ipn/serve_test.go b/ipn/serve_test.go index 8be39a1ed81ce..bf043ca39f372 100644 --- a/ipn/serve_test.go +++ b/ipn/serve_test.go @@ -283,11 +283,11 @@ func TestExpandProxyTargetDev(t *testing.T) { wantErr bool }{ {name: "port-only", input: "8080", expected: "http://127.0.0.1:8080"}, - {name: "hostname+port", input: "localhost:8080", expected: "http://localhost:8080"}, + {name: "hostname-and-port", input: "localhost:8080", expected: "http://localhost:8080"}, {name: "no-change", input: "http://127.0.0.1:8080", expected: "http://127.0.0.1:8080"}, {name: "include-path", input: "http://127.0.0.1:8080/foo", expected: "http://127.0.0.1:8080/foo"}, {name: "https-scheme", input: "https://localhost:8080", expected: "https://localhost:8080"}, - {name: "https+insecure-scheme", input: "https+insecure://localhost:8080", expected: "https+insecure://localhost:8080"}, + {name: "https-insecure-scheme", input: "https+insecure://localhost:8080", expected: "https+insecure://localhost:8080"}, {name: "change-default-scheme", input: "localhost:8080", defaultScheme: "https", expected: "https://localhost:8080"}, {name: "change-supported-schemes", input: "localhost:8080", defaultScheme: "tcp", supportedSchemes: []string{"tcp"}, expected: "tcp://localhost:8080"}, {name: "remote-target", input: "https://example.com:8080", expected: "https://example.com:8080"}, diff --git a/ipn/store/awsstore/store_aws.go b/ipn/store/awsstore/store_aws.go index e06e00eb3d3dd..feb86e457805a 100644 --- a/ipn/store/awsstore/store_aws.go +++ b/ipn/store/awsstore/store_aws.go @@ -189,8 +189,7 @@ func (s *awsStore) LoadState() error { ) if err != nil { - var pnf *ssmTypes.ParameterNotFound - if errors.As(err, &pnf) { + if _, ok := errors.AsType[*ssmTypes.ParameterNotFound](err); ok { // Create the parameter as it does not exist yet // and return directly as it is defacto empty return s.persistState() diff --git a/k8s-operator/api.md b/k8s-operator/api.md index 5a60f66e039d0..9101c95ca6e59 100644 --- a/k8s-operator/api.md +++ b/k8s-operator/api.md @@ -483,6 +483,8 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | | `tolerations` _[Toleration](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#toleration-v1-core) array_ | If specified, applies tolerations to the pods deployed by the DNSConfig resource. | | | +| `affinity` _[Affinity](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#affinity-v1-core)_ | If specified, applies affinity rules to the pods deployed by the DNSConfig resource. | | | +| `nodeSelector` _object (keys:string, values:string)_ | If specified, applies node selector rules to the pods deployed by the DNSConfig resource. | | | #### NameserverService diff --git a/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go b/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go index c1a2e7906fcd8..529114c2e1957 100644 --- a/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go +++ b/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go @@ -113,6 +113,12 @@ type NameserverPod struct { // If specified, applies tolerations to the pods deployed by the DNSConfig resource. // +optional Tolerations []corev1.Toleration `json:"tolerations,omitempty"` + // If specified, applies affinity rules to the pods deployed by the DNSConfig resource. + // +optional + Affinity *corev1.Affinity `json:"affinity,omitzero"` + // If specified, applies node selector rules to the pods deployed by the DNSConfig resource. + // +optional + NodeSelector map[string]string `json:"nodeSelector,omitzero"` } type DNSConfigStatus struct { diff --git a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go index 2528c89f364d6..b401c6d8778f5 100644 --- a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go +++ b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go @@ -469,6 +469,18 @@ func (in *NameserverPod) DeepCopyInto(out *NameserverPod) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.Affinity != nil { + in, out := &in.Affinity, &out.Affinity + *out = new(corev1.Affinity) + (*in).DeepCopyInto(*out) + } + if in.NodeSelector != nil { + in, out := &in.NodeSelector, &out.NodeSelector + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new NameserverPod. diff --git a/k8s-operator/reconciler/proxygrouppolicy/proxygrouppolicy.go b/k8s-operator/reconciler/proxygrouppolicy/proxygrouppolicy.go index 0541a5cf3691b..b4c311046bc7c 100644 --- a/k8s-operator/reconciler/proxygrouppolicy/proxygrouppolicy.go +++ b/k8s-operator/reconciler/proxygrouppolicy/proxygrouppolicy.go @@ -24,7 +24,6 @@ import ( "sigs.k8s.io/controller-runtime/pkg/reconcile" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" - "tailscale.com/types/ptr" "tailscale.com/util/set" ) @@ -243,7 +242,7 @@ func (r *Reconciler) generateIngressPolicy(ctx context.Context, namespace string ResourceVersion: policy.ResourceVersion, }, Spec: admr.ValidatingAdmissionPolicySpec{ - FailurePolicy: ptr.To(admr.Fail), + FailurePolicy: new(admr.Fail), MatchConstraints: &admr.MatchResources{ // The operator allows ingress via Ingress resources & Service resources (that use the "tailscale" load // balancer class), so we have two resource rules here with multiple validation expressions that attempt @@ -304,7 +303,7 @@ func (r *Reconciler) generateEgressPolicy(ctx context.Context, namespace string, ResourceVersion: policy.ResourceVersion, }, Spec: admr.ValidatingAdmissionPolicySpec{ - FailurePolicy: ptr.To(admr.Fail), + FailurePolicy: new(admr.Fail), MatchConstraints: &admr.MatchResources{ ResourceRules: []admr.NamedRuleWithOperations{ { diff --git a/k8s-operator/reconciler/proxygrouppolicy/proxygrouppolicy_test.go b/k8s-operator/reconciler/proxygrouppolicy/proxygrouppolicy_test.go index 6710eac7406d6..d5c0b6d353e21 100644 --- a/k8s-operator/reconciler/proxygrouppolicy/proxygrouppolicy_test.go +++ b/k8s-operator/reconciler/proxygrouppolicy/proxygrouppolicy_test.go @@ -30,7 +30,7 @@ func TestReconciler_Reconcile(t *testing.T) { ExpectsError bool }{ { - Name: "single policy, denies all", + Name: "single-policy-denies-all", ExpectedPolicyCount: 2, Request: reconcile.Request{ NamespacedName: types.NamespacedName{ @@ -53,7 +53,7 @@ func TestReconciler_Reconcile(t *testing.T) { }, }, { - Name: "multiple policies merged", + Name: "multiple-policies-merged", ExpectedPolicyCount: 2, Request: reconcile.Request{ NamespacedName: types.NamespacedName{ @@ -89,7 +89,7 @@ func TestReconciler_Reconcile(t *testing.T) { }, }, { - Name: "no policies, no child resources", + Name: "no-policies-no-child-resources", ExpectedPolicyCount: 0, Request: reconcile.Request{ NamespacedName: types.NamespacedName{ diff --git a/k8s-operator/reconciler/tailnet/mocks_test.go b/k8s-operator/reconciler/tailnet/mocks_test.go index 4342556885013..3931e4d33bbb5 100644 --- a/k8s-operator/reconciler/tailnet/mocks_test.go +++ b/k8s-operator/reconciler/tailnet/mocks_test.go @@ -9,7 +9,9 @@ import ( "context" "io" - "tailscale.com/internal/client/tailscale" + "tailscale.com/client/tailscale/v2" + + "tailscale.com/k8s-operator/tsclient" ) type ( @@ -18,28 +20,62 @@ type ( ErrorOnKeys bool ErrorOnServices bool } + + MockDeviceResource struct { + tsclient.DeviceResource + + Error bool + } + + MockKeyResource struct { + tsclient.KeyResource + + Error bool + } + + MockVIPServiceResource struct { + tsclient.VIPServiceResource + + Error bool + } ) -func (m MockTailnetClient) Devices(_ context.Context, _ *tailscale.DeviceFieldsOpts) ([]*tailscale.Device, error) { - if m.ErrorOnDevices { +func (m MockKeyResource) List(_ context.Context, _ bool) ([]tailscale.Key, error) { + if m.Error { return nil, io.EOF } return nil, nil } -func (m MockTailnetClient) Keys(_ context.Context) ([]string, error) { - if m.ErrorOnKeys { +func (m MockDeviceResource) List(_ context.Context, _ ...tailscale.ListDevicesOptions) ([]tailscale.Device, error) { + if m.Error { return nil, io.EOF } return nil, nil } -func (m MockTailnetClient) ListVIPServices(_ context.Context) (*tailscale.VIPServiceList, error) { - if m.ErrorOnServices { +func (m MockVIPServiceResource) List(_ context.Context) ([]tailscale.VIPService, error) { + if m.Error { return nil, io.EOF } return nil, nil } + +func (m MockTailnetClient) Devices() tsclient.DeviceResource { + return MockDeviceResource{Error: m.ErrorOnDevices} +} + +func (m MockTailnetClient) Keys() tsclient.KeyResource { + return MockKeyResource{Error: m.ErrorOnKeys} +} + +func (m MockTailnetClient) VIPServices() tsclient.VIPServiceResource { + return MockVIPServiceResource{Error: m.ErrorOnServices} +} + +func (m MockTailnetClient) LoginURL() string { + return "" +} diff --git a/k8s-operator/reconciler/tailnet/tailnet.go b/k8s-operator/reconciler/tailnet/tailnet.go index 2e7004b698c93..e30bb21702e39 100644 --- a/k8s-operator/reconciler/tailnet/tailnet.go +++ b/k8s-operator/reconciler/tailnet/tailnet.go @@ -12,12 +12,11 @@ import ( "context" "errors" "fmt" + "net/url" "sync" "time" "go.uber.org/zap" - "golang.org/x/oauth2" - "golang.org/x/oauth2/clientcredentials" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -26,12 +25,13 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/client/tailscale/v2" - "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" operatorutils "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/k8s-operator/reconciler" + "tailscale.com/k8s-operator/tsclient" "tailscale.com/kube/kubetypes" "tailscale.com/tstime" "tailscale.com/util/clientmetric" @@ -47,7 +47,8 @@ type ( tailscaleNamespace string clock tstime.Clock logger *zap.SugaredLogger - clientFunc func(*tsapi.Tailnet, *corev1.Secret) TailscaleClient + clientFunc func(*tsapi.Tailnet, *corev1.Secret) tsclient.Client + registry ClientRegistry // Metrics related fields mu sync.Mutex @@ -68,14 +69,18 @@ type ( Logger *zap.SugaredLogger // ClientFunc is a function that takes tailscale credentials and returns an implementation for the Tailscale // HTTP API. This should generally be nil unless needed for testing. - ClientFunc func(*tsapi.Tailnet, *corev1.Secret) TailscaleClient + ClientFunc func(*tsapi.Tailnet, *corev1.Secret) tsclient.Client + // Registry is used to store and share initialized tailscale clients for use by other reconcilers. + Registry ClientRegistry } - // The TailscaleClient interface describes types that interact with the Tailscale HTTP API. - TailscaleClient interface { - Devices(context.Context, *tailscale.DeviceFieldsOpts) ([]*tailscale.Device, error) - Keys(ctx context.Context) ([]string, error) - ListVIPServices(ctx context.Context) (*tailscale.VIPServiceList, error) + // The ClientRegistry interface describes types that can store initialized tailscale clients for use by other + // reconcilers. + ClientRegistry interface { + // Add should store the given tsclient.Client implementation for a specified tailnet. + Add(tailnet string, client tsclient.Client, ready bool) + // Remove should remove any tsclient.Client implementation for a specified tailnet. + Remove(tailnet string) } ) @@ -90,6 +95,7 @@ func NewReconciler(options ReconcilerOptions) *Reconciler { clock: options.Clock, logger: options.Logger.Named(reconcilerName), clientFunc: options.ClientFunc, + registry: options.Registry, } } @@ -137,6 +143,7 @@ func (r *Reconciler) delete(ctx context.Context, tailnet *tsapi.Tailnet) (reconc r.tailnets.Remove(tailnet.UID) r.mu.Unlock() gaugeTailnetResources.Set(int64(r.tailnets.Len())) + r.registry.Remove(tailnet.Name) return reconcile.Result{}, nil } @@ -193,11 +200,16 @@ func (r *Reconciler) createOrUpdate(ctx context.Context, tailnet *tsapi.Tailnet) return reconcile.Result{RequeueAfter: time.Minute / 2}, nil } - tsClient := r.createClient(ctx, tailnet, &secret) + tsClient, err := r.createClient(tailnet, &secret) + if err != nil { + return reconcile.Result{}, fmt.Errorf("failed to create tailnet client: %w", err) + } // Second, we ensure the OAuth credentials supplied in the secret are valid and have the required scopes to access // the various API endpoints required by the operator. if ok := r.ensurePermissions(ctx, tsClient, tailnet); !ok { + r.registry.Add(tailnet.Name, tsClient, false) + if err = r.Status().Update(ctx, tailnet); err != nil { return reconcile.Result{}, fmt.Errorf("failed to update Tailnet status for %q: %w", tailnet.Name, err) } @@ -226,6 +238,8 @@ func (r *Reconciler) createOrUpdate(ctx context.Context, tailnet *tsapi.Tailnet) return reconcile.Result{}, fmt.Errorf("failed to add finalizer to Tailnet %q: %w", tailnet.Name, err) } + r.registry.Add(tailnet.Name, tsClient, true) + return reconcile.Result{}, nil } @@ -235,9 +249,9 @@ const ( clientSecretKey = "client_secret" ) -func (r *Reconciler) createClient(ctx context.Context, tailnet *tsapi.Tailnet, secret *corev1.Secret) TailscaleClient { +func (r *Reconciler) createClient(tailnet *tsapi.Tailnet, secret *corev1.Secret) (tsclient.Client, error) { if r.clientFunc != nil { - return r.clientFunc(tailnet, secret) + return r.clientFunc(tailnet, secret), nil } baseURL := ipn.DefaultControlURL @@ -245,38 +259,36 @@ func (r *Reconciler) createClient(ctx context.Context, tailnet *tsapi.Tailnet, s baseURL = tailnet.Spec.LoginURL } - credentials := clientcredentials.Config{ - ClientID: string(secret.Data[clientIDKey]), - ClientSecret: string(secret.Data[clientSecretKey]), - TokenURL: baseURL + "/api/v2/oauth/token", + base, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse base URL %q: %w", baseURL, err) } - source := credentials.TokenSource(ctx) - httpClient := oauth2.NewClient(ctx, source) - - tsClient := tailscale.NewClient("-", nil) - tsClient.UserAgent = "tailscale-k8s-operator" - tsClient.HTTPClient = httpClient - tsClient.BaseURL = baseURL - - return tsClient + return tsclient.Wrap(&tailscale.Client{ + BaseURL: base, + UserAgent: "tailscale-k8s-operator", + Auth: &tailscale.OAuth{ + ClientID: string(secret.Data[clientIDKey]), + ClientSecret: string(secret.Data[clientSecretKey]), + }, + }), nil } -func (r *Reconciler) ensurePermissions(ctx context.Context, tsClient TailscaleClient, tailnet *tsapi.Tailnet) bool { +func (r *Reconciler) ensurePermissions(ctx context.Context, tsClient tsclient.Client, tailnet *tsapi.Tailnet) bool { // Perform basic list requests here to confirm that the OAuth credentials referenced on the Tailnet resource // can perform the basic operations required for the operator to function. This has a caveat of only performing // read actions, as we don't want to create arbitrary keys and VIP services. However, it will catch when a user // has completely forgotten an entire scope that's required. var errs error - if _, err := tsClient.Devices(ctx, nil); err != nil { + if _, err := tsClient.Devices().List(ctx); err != nil { errs = errors.Join(errs, fmt.Errorf("failed to list devices: %w", err)) } - if _, err := tsClient.Keys(ctx); err != nil { + if _, err := tsClient.Keys().List(ctx, false); err != nil { errs = errors.Join(errs, fmt.Errorf("failed to list auth keys: %w", err)) } - if _, err := tsClient.ListVIPServices(ctx); err != nil { + if _, err := tsClient.VIPServices().List(ctx); err != nil { errs = errors.Join(errs, fmt.Errorf("failed to list tailscale services: %w", err)) } diff --git a/k8s-operator/reconciler/tailnet/tailnet_test.go b/k8s-operator/reconciler/tailnet/tailnet_test.go index 0ed2ca598d720..513ed7b84dcd1 100644 --- a/k8s-operator/reconciler/tailnet/tailnet_test.go +++ b/k8s-operator/reconciler/tailnet/tailnet_test.go @@ -18,6 +18,7 @@ import ( tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/k8s-operator/reconciler/tailnet" + "tailscale.com/k8s-operator/tsclient" "tailscale.com/tstest" ) @@ -36,10 +37,10 @@ func TestReconciler_Reconcile(t *testing.T) { Secret *corev1.Secret ExpectsError bool ExpectedConditions []metav1.Condition - ClientFunc func(*tsapi.Tailnet, *corev1.Secret) tailnet.TailscaleClient + ClientFunc func(*tsapi.Tailnet, *corev1.Secret) tsclient.Client }{ { - Name: "ignores unknown tailnet requests", + Name: "ignores-unknown-tailnet-requests", Request: reconcile.Request{ NamespacedName: types.NamespacedName{ Name: "test", @@ -47,7 +48,7 @@ func TestReconciler_Reconcile(t *testing.T) { }, }, { - Name: "invalid status for missing secret", + Name: "invalid-status-missing-secret", Request: reconcile.Request{ NamespacedName: types.NamespacedName{ Name: "test", @@ -73,7 +74,7 @@ func TestReconciler_Reconcile(t *testing.T) { }, }, { - Name: "invalid status for empty secret", + Name: "invalid-status-empty-secret", Request: reconcile.Request{ NamespacedName: types.NamespacedName{ Name: "test", @@ -105,7 +106,7 @@ func TestReconciler_Reconcile(t *testing.T) { }, }, { - Name: "invalid status for missing client id", + Name: "invalid-status-missing-client-id", Request: reconcile.Request{ NamespacedName: types.NamespacedName{ Name: "test", @@ -140,7 +141,7 @@ func TestReconciler_Reconcile(t *testing.T) { }, }, { - Name: "invalid status for missing client secret", + Name: "invalid-status-missing-client-secret", Request: reconcile.Request{ NamespacedName: types.NamespacedName{ Name: "test", @@ -175,7 +176,7 @@ func TestReconciler_Reconcile(t *testing.T) { }, }, { - Name: "invalid status for bad devices scope", + Name: "invalid-status-bad-devices-scope", Request: reconcile.Request{ NamespacedName: types.NamespacedName{ Name: "test", @@ -201,7 +202,7 @@ func TestReconciler_Reconcile(t *testing.T) { "client_secret": []byte("test"), }, }, - ClientFunc: func(_ *tsapi.Tailnet, _ *corev1.Secret) tailnet.TailscaleClient { + ClientFunc: func(_ *tsapi.Tailnet, _ *corev1.Secret) tsclient.Client { return &MockTailnetClient{ErrorOnDevices: true} }, ExpectedConditions: []metav1.Condition{ @@ -214,7 +215,7 @@ func TestReconciler_Reconcile(t *testing.T) { }, }, { - Name: "invalid status for bad services scope", + Name: "invalid-status-bad-services-scope", Request: reconcile.Request{ NamespacedName: types.NamespacedName{ Name: "test", @@ -240,7 +241,7 @@ func TestReconciler_Reconcile(t *testing.T) { "client_secret": []byte("test"), }, }, - ClientFunc: func(_ *tsapi.Tailnet, _ *corev1.Secret) tailnet.TailscaleClient { + ClientFunc: func(_ *tsapi.Tailnet, _ *corev1.Secret) tsclient.Client { return &MockTailnetClient{ErrorOnServices: true} }, ExpectedConditions: []metav1.Condition{ @@ -253,7 +254,7 @@ func TestReconciler_Reconcile(t *testing.T) { }, }, { - Name: "invalid status for bad keys scope", + Name: "invalid-status-bad-keys-scope", Request: reconcile.Request{ NamespacedName: types.NamespacedName{ Name: "test", @@ -279,7 +280,7 @@ func TestReconciler_Reconcile(t *testing.T) { "client_secret": []byte("test"), }, }, - ClientFunc: func(_ *tsapi.Tailnet, _ *corev1.Secret) tailnet.TailscaleClient { + ClientFunc: func(_ *tsapi.Tailnet, _ *corev1.Secret) tsclient.Client { return &MockTailnetClient{ErrorOnKeys: true} }, ExpectedConditions: []metav1.Condition{ @@ -292,7 +293,7 @@ func TestReconciler_Reconcile(t *testing.T) { }, }, { - Name: "ready when valid and scopes are correct", + Name: "ready-valid-scopes-correct", Request: reconcile.Request{ NamespacedName: types.NamespacedName{ Name: "default", @@ -318,7 +319,7 @@ func TestReconciler_Reconcile(t *testing.T) { "client_secret": []byte("test"), }, }, - ClientFunc: func(_ *tsapi.Tailnet, _ *corev1.Secret) tailnet.TailscaleClient { + ClientFunc: func(_ *tsapi.Tailnet, _ *corev1.Secret) tsclient.Client { return &MockTailnetClient{} }, ExpectedConditions: []metav1.Condition{ @@ -349,6 +350,7 @@ func TestReconciler_Reconcile(t *testing.T) { Logger: logger.Sugar(), ClientFunc: tc.ClientFunc, TailscaleNamespace: "tailscale", + Registry: tsclient.NewProvider(nil), } reconciler := tailnet.NewReconciler(opts) diff --git a/k8s-operator/sessionrecording/spdy/frame.go b/k8s-operator/sessionrecording/spdy/frame.go index 7087db3c32166..3ca661e0b6a2a 100644 --- a/k8s-operator/sessionrecording/spdy/frame.go +++ b/k8s-operator/sessionrecording/spdy/frame.go @@ -211,7 +211,7 @@ func parseHeaders(decompressor io.Reader, log *zap.SugaredLogger) (http.Header, return nil, fmt.Errorf("error determining num headers: %v", err) } h := make(http.Header, numHeaders) - for i := uint32(0); i < numHeaders; i++ { + for range numHeaders { name, err := readLenBytes() if err != nil { return nil, err @@ -224,7 +224,7 @@ func parseHeaders(decompressor io.Reader, log *zap.SugaredLogger) (http.Header, if err != nil { return nil, fmt.Errorf("error reading header data: %w", err) } - for _, v := range bytes.Split(val, headerSep) { + for v := range bytes.SplitSeq(val, headerSep) { h.Add(ns, string(v)) } } diff --git a/k8s-operator/sessionrecording/ws/conn.go b/k8s-operator/sessionrecording/ws/conn.go index 4762630ca7522..ed0ecc7ac7f0d 100644 --- a/k8s-operator/sessionrecording/ws/conn.go +++ b/k8s-operator/sessionrecording/ws/conn.go @@ -147,88 +147,12 @@ func (c *conn) Read(b []byte) (int, error) { return 0, nil } - // TODO(tomhjp): If we get multiple frames in a single Read with different - // types, we may parse the second frame with the wrong type. - typ := messageType(opcode(b)) - if (typ == noOpcode && c.readMsgIsIncomplete()) || c.readBufHasIncompleteFragment() { // subsequent fragment - if typ, err = c.curReadMsgType(); err != nil { - return 0, err - } - } - - // A control message can not be fragmented and we are not interested in - // these messages. Just return. - // TODO(tomhjp): If we get multiple frames in a single Read, we may skip - // some non-control messages. - if isControlMessage(typ) { - return n, nil - } - - // The only data message type that Kubernetes supports is binary message. - // If we received another message type, return and let the API server close the connection. - // https://github.com/kubernetes/client-go/blob/release-1.30/tools/remotecommand/websocket.go#L281 - if typ != binaryMessage { - c.log.Infof("[unexpected] received a data message with a type that is not binary message type %v", typ) - return n, nil - } - if _, err := c.readBuf.Write(b[:n]); err != nil { return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err) } - for c.readBuf.Len() != 0 { - readMsg := &message{typ: typ} // start a new message... - // ... or pick up an already started one if the previous fragment was not final. - if c.readMsgIsIncomplete() { - readMsg = c.currentReadMsg - } - - ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log) - if err != nil { - return 0, fmt.Errorf("error parsing message: %v", err) - } - if !ok { // incomplete fragment - return n, nil - } - c.readBuf.Next(len(readMsg.raw)) - - if readMsg.isFinalized && !c.readMsgIsIncomplete() { - // we want to send stream resize messages for terminal sessions - // Stream IDs for websocket streams are static. - // https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218 - if readMsg.streamID.Load() == remotecommand.StreamResize && c.hasTerm { - var msg tsrecorder.ResizeMsg - if err = json.Unmarshal(readMsg.payload, &msg); err != nil { - return 0, fmt.Errorf("error umarshalling resize message: %w", err) - } - - c.ch.Width = msg.Width - c.ch.Height = msg.Height - - var isInitialResize bool - c.writeCastHeaderOnce.Do(func() { - isInitialResize = true - // If this is a session with a terminal attached, - // we must wait for the terminal width and - // height to be parsed from a resize message - // before sending CastHeader, else tsrecorder - // will not be able to play this recording. - err = c.rec.WriteCastHeader(c.ch) - close(c.initialCastHeaderSent) - }) - if err != nil { - return 0, fmt.Errorf("error writing CastHeader: %w", err) - } - - if !isInitialResize { - if err := c.rec.WriteResize(msg.Height, msg.Width); err != nil { - return 0, fmt.Errorf("error writing resize message: %w", err) - } - } - } - } - - c.currentReadMsg = readMsg + if _, err := c.processFrames(&c.readBuf, &c.currentReadMsg); err != nil { + return 0, err } return n, nil @@ -245,64 +169,21 @@ func (c *conn) Write(b []byte) (int, error) { return 0, nil } - typ := messageType(opcode(b)) - // If we are in process of parsing a message fragment, the received - // bytes are not structured as a message fragment and can not be used to - // determine a message fragment. - if c.writeBufHasIncompleteFragment() { // buffer contains previous incomplete fragment - var err error - if typ, err = c.curWriteMsgType(); err != nil { - return 0, err - } - } - - if isControlMessage(typ) { - return c.Conn.Write(b) - } - - writeMsg := &message{typ: typ} // start a new message... - // ... or continue the existing one if it has not been finalized. - if c.writeMsgIsIncomplete() || c.writeBufHasIncompleteFragment() { - writeMsg = c.currentWriteMsg - } - if _, err := c.writeBuf.Write(b); err != nil { c.log.Errorf("write: error writing to write buf: %v", err) return 0, fmt.Errorf("[unexpected] error writing to internal write buffer: %w", err) } - ok, err := writeMsg.Parse(c.writeBuf.Bytes(), c.log) + raw, err := c.processFrames(&c.writeBuf, &c.currentWriteMsg) if err != nil { - c.log.Errorf("write: parsing a message errored: %v", err) - return 0, fmt.Errorf("write: error parsing message: %v", err) - } - - c.currentWriteMsg = writeMsg - if !ok { // incomplete fragment - return len(b), nil + return 0, err } - - c.writeBuf.Next(len(writeMsg.raw)) // advance frame - - if len(writeMsg.payload) != 0 && writeMsg.isFinalized { - if writeMsg.streamID.Load() == remotecommand.StreamStdOut || writeMsg.streamID.Load() == remotecommand.StreamStdErr { - // we must wait for confirmation that the initial cast header was sent before proceeding with any more writes - select { - case <-c.ctx.Done(): - return 0, c.ctx.Err() - case <-c.initialCastHeaderSent: - if err := c.rec.WriteOutput(writeMsg.payload); err != nil { - return 0, fmt.Errorf("error writing message to recorder: %w", err) - } - } + if len(raw) > 0 { + if _, err := c.Conn.Write(raw); err != nil { + return 0, err } } - _, err = c.Conn.Write(c.currentWriteMsg.raw) - if err != nil { - c.log.Errorf("write: error writing to conn: %v", err) - } - return len(b), nil } @@ -318,48 +199,125 @@ func (c *conn) Close() error { return errors.Join(connCloseErr, recCloseErr) } -// writeBufHasIncompleteFragment returns true if the latest data message -// fragment written to the connection was incomplete and the following write -// must be the remaining payload bytes of that fragment. -func (c *conn) writeBufHasIncompleteFragment() bool { - return c.writeBuf.Len() != 0 +// handleData records a finalized data message to the session recorder. +// It handles resize messages (updating terminal dimensions and writing the +// CastHeader on the first one) and stdout/stderr messages (recording output). +// Other stream IDs (stdin, error) are ignored. +func (c *conn) handleData(msg *message) error { + switch msg.streamID.Load() { + case remotecommand.StreamResize: + if !c.hasTerm { + return nil + } + var rm tsrecorder.ResizeMsg + if err := json.Unmarshal(msg.payload, &rm); err != nil { + return fmt.Errorf("error unmarshalling resize message: %w", err) + } + c.ch.Width = rm.Width + c.ch.Height = rm.Height + + // The first resize writes the CastHeader and unblocks output recording. + var headerErr error + var isInitialResize bool + c.writeCastHeaderOnce.Do(func() { + isInitialResize = true + headerErr = c.rec.WriteCastHeader(c.ch) + close(c.initialCastHeaderSent) + }) + if headerErr != nil { + return fmt.Errorf("error writing CastHeader: %w", headerErr) + } + if !isInitialResize { + if err := c.rec.WriteResize(rm.Height, rm.Width); err != nil { + return fmt.Errorf("error writing resize message: %w", err) + } + } + case remotecommand.StreamStdOut, remotecommand.StreamStdErr: + // Wait for the CastHeader before recording any output. + select { + case <-c.ctx.Done(): + return c.ctx.Err() + case <-c.initialCastHeaderSent: + if err := c.rec.WriteOutput(msg.payload); err != nil { + return fmt.Errorf("error writing message to recorder: %w", err) + } + } + } + return nil } -// readBufHasIncompleteFragment returns true if the latest data message -// fragment read from the connection was incomplete and the following read -// must be the remaining payload bytes of that fragment. -func (c *conn) readBufHasIncompleteFragment() bool { - return c.readBuf.Len() != 0 -} +// processFrames drains complete WebSocket frames from buf, recording session +// data via handleData for finalized binary messages. It returns the raw bytes +// of every consumed frame so the Write path can forward them to the underlying +// connection. Incomplete frames are left in buf for the next call. +// +// Control frames are consumed whole without inspection. Non-binary data frames +// are unexpected (k8s only uses binary) and cause the buffer to be discarded. +func (c *conn) processFrames( + buf *bytes.Buffer, + curMsg **message, +) ([]byte, error) { + var raw []byte + for buf.Len() != 0 { + b := buf.Bytes() + if len(b) < 2 { + return raw, nil + } -// writeMsgIsIncomplete returns true if the latest WebSocket message written to -// the connection was fragmented and the next data message fragment written to -// the connection must be a fragment of that message. -// https://www.rfc-editor.org/rfc/rfc6455#section-5.4 -func (c *conn) writeMsgIsIncomplete() bool { - return c.currentWriteMsg != nil && !c.currentWriteMsg.isFinalized -} + // Continuation frames (opcode 0) inherit the type of the in-progress message. + typ := messageType(opcode(b)) + if typ == noOpcode && *curMsg != nil { + typ = (*curMsg).typ + } -// readMsgIsIncomplete returns true if the latest WebSocket message written to -// the connection was fragmented and the next data message fragment written to -// the connection must be a fragment of that message. -// https://www.rfc-editor.org/rfc/rfc6455#section-5.4 -func (c *conn) readMsgIsIncomplete() bool { - return c.currentReadMsg != nil && !c.currentReadMsg.isFinalized -} + // Control frames: pass through without inspection. + if isControlMessage(typ) { + maskSet := isMasked(b) + payloadLen, payloadOffset, _, err := fragmentDimensions(b, maskSet) + if err != nil { + return nil, fmt.Errorf("error parsing control frame: %w", err) + } + frameLen := int(payloadOffset + payloadLen) + if len(b) < frameLen { + return raw, nil // incomplete control frame + } + raw = append(raw, b[:frameLen]...) + buf.Next(frameLen) + continue + } -func (c *conn) curReadMsgType() (messageType, error) { - if c.currentReadMsg != nil { - return c.currentReadMsg.typ, nil - } - return 0, errors.New("[unexpected] attempted to determine type for nil message") -} + // k8s remotecommand only uses binary data messages. + if typ != binaryMessage { + c.log.Infof("[unexpected] received a data message with a type that is not binary message type %v", typ) + buf.Reset() + return raw, nil + } -func (c *conn) curWriteMsgType() (messageType, error) { - if c.currentWriteMsg != nil { - return c.currentWriteMsg.typ, nil + // Continue a fragmented message or start a new one. + msg := &message{typ: typ} + if *curMsg != nil && !(*curMsg).isFinalized { + msg = *curMsg + } + + ok, err := msg.Parse(b, c.log) + if err != nil { + return nil, fmt.Errorf("error parsing message: %w", err) + } + if !ok { + *curMsg = msg + return raw, nil // incomplete fragment, wait for more bytes + } + buf.Next(len(msg.raw)) + *curMsg = msg + + raw = append(raw, msg.raw...) + if msg.isFinalized && len(msg.payload) > 0 { + if err := c.handleData(msg); err != nil { + return nil, err + } + } } - return 0, errors.New("[unexpected] attempted to determine type for nil message") + return raw, nil } // opcode reads the websocket message opcode that denotes the message type. diff --git a/k8s-operator/sessionrecording/ws/conn_test.go b/k8s-operator/sessionrecording/ws/conn_test.go index 0b4353698cd9f..ea9aca19296d5 100644 --- a/k8s-operator/sessionrecording/ws/conn_test.go +++ b/k8s-operator/sessionrecording/ws/conn_test.go @@ -37,6 +37,17 @@ func Test_conn_Read(t *testing.T) { wantCastHeaderHeight int wantRecorded []byte }{ + // Empty final continuation frame after a resize frame. + { + name: "continuation_frame_with_empty_payload", + inputs: [][]byte{ + append([]byte{0x02, lenResizeMsgPayload}, testResizeMsg...), + {0x80, 0x00}, + }, + wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20), + wantCastHeaderWidth: 10, + wantCastHeaderHeight: 20, + }, { name: "single_read_control_message", inputs: [][]byte{{0x88, 0x0}}, @@ -58,6 +69,19 @@ func Test_conn_Read(t *testing.T) { wantCastHeaderWidth: 10, wantCastHeaderHeight: 20, }, + { + // A control frame (close) followed by a resize data frame in + // a single Read. Without the frame loop, the close frame + // would cause the data frame to be skipped. + name: "control_then_data_in_one_read", + inputs: [][]byte{ + // close frame (0x88, len 0), then resize data frame + append([]byte{0x88, 0x00, 0x82, lenResizeMsgPayload}, testResizeMsg...), + }, + wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20), + wantCastHeaderWidth: 10, + wantCastHeaderHeight: 20, + }, { name: "resize_data_frame_two_in_one_read", inputs: [][]byte{ @@ -156,6 +180,26 @@ func Test_conn_Write(t *testing.T) { wantRecorded []byte hasTerm bool }{ + // Empty final continuation frame; stream ID already set by + // the initial fragment. + { + name: "continuation_frame_with_empty_payload", + inputs: [][]byte{ + {0x02, 0x03, 0x01, 0x07, 0x08}, + {0x80, 0x00}, + }, + wantForwarded: []byte{0x02, 0x03, 0x01, 0x07, 0x08, 0x80, 0x00}, + wantRecorded: fakes.CastLine(t, []byte{0x07, 0x08}, cl), + }, + // Same as above but both fragments land in one Write call. + { + name: "continuation_frame_with_empty_payload_single_write", + inputs: [][]byte{ + {0x02, 0x03, 0x01, 0x07, 0x08, 0x80, 0x00}, + }, + wantForwarded: []byte{0x02, 0x03, 0x01, 0x07, 0x08, 0x80, 0x00}, + wantRecorded: fakes.CastLine(t, []byte{0x07, 0x08}, cl), + }, { name: "single_write_control_frame", inputs: [][]byte{{0x88, 0x0}}, @@ -203,6 +247,38 @@ func Test_conn_Write(t *testing.T) { wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl), hasTerm: true, }, + { + // Two complete WebSocket frames coalesced into a single + // Write() call: a stdout binary frame followed by a close + // frame. Without a loop in the Write path, the close frame + // gets stranded in writeBuf and misinterpreted on the next + // Write. + name: "two_frames_in_one_write_data_then_close", + inputs: [][]byte{ + // binary frame (opcode 0x2, FIN set = 0x82), payload len 3, + // stream ID 1 (stdout), two data bytes, + // then close frame (opcode 0x8, FIN set = 0x88), payload len 0 + {0x82, 0x03, 0x01, 0x07, 0x08, 0x88, 0x00}, + }, + wantForwarded: []byte{0x82, 0x03, 0x01, 0x07, 0x08, 0x88, 0x00}, + wantRecorded: fakes.CastLine(t, []byte{0x07, 0x08}, cl), + }, + { + // Two complete stdout data frames in one Write() call. + // Mirrors the "resize_data_frame_two_in_one_read" test + // for the Read path. + name: "two_data_frames_in_one_write", + inputs: [][]byte{ + // first: binary frame, payload len 3, stdout stream, two data bytes + // second: binary frame, payload len 3, stdout stream, two different data bytes + {0x82, 0x03, 0x01, 0x07, 0x08, 0x82, 0x03, 0x01, 0x09, 0x0a}, + }, + wantForwarded: []byte{0x82, 0x03, 0x01, 0x07, 0x08, 0x82, 0x03, 0x01, 0x09, 0x0a}, + wantRecorded: append( + fakes.CastLine(t, []byte{0x07, 0x08}, cl), + fakes.CastLine(t, []byte{0x09, 0x0a}, cl)..., + ), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -254,12 +330,20 @@ func Test_conn_ReadRand(t *testing.T) { if err != nil { t.Fatalf("error creating a test logger: %v", err) } + cl := tstest.NewClock(tstest.ClockOpts{}) + sr := &fakes.TestSessionRecorder{} + rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar()) for i := range 100 { tc := &fakes.TestConn{} tc.ResetReadBuf() + headerSent := make(chan struct{}) + close(headerSent) // pre-close so handleData doesn't block c := &conn{ - Conn: tc, - log: zl.Sugar(), + Conn: tc, + log: zl.Sugar(), + ctx: context.Background(), + rec: rec, + initialCastHeaderSent: headerSent, } bb := fakes.RandomBytes(t) for j, input := range bb { diff --git a/k8s-operator/sessionrecording/ws/message.go b/k8s-operator/sessionrecording/ws/message.go index 36359996a7c12..47177ef1977fd 100644 --- a/k8s-operator/sessionrecording/ws/message.go +++ b/k8s-operator/sessionrecording/ws/message.go @@ -99,19 +99,19 @@ func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) { } isInitialFragment := len(msg.raw) == 0 - msg.isFinalized = isFinalFragment(b) - + finalized := isFinalFragment(b) maskSet := isMasked(b) payloadLength, payloadOffset, maskOffset, err := fragmentDimensions(b, maskSet) if err != nil { return false, fmt.Errorf("error determining payload length: %w", err) } - log.Debugf("parse: parsing a message fragment with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t, is initial fragment: %t", payloadLength, payloadOffset, maskOffset, maskSet, msg.isFinalized, isInitialFragment) + log.Debugf("parse: parsing a message fragment with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t, is initial fragment: %t", payloadLength, payloadOffset, maskOffset, maskSet, finalized, isInitialFragment) if len(b) < int(payloadOffset+payloadLength) { // incomplete fragment return false, nil } + msg.isFinalized = finalized // TODO (irbekrm): perhaps only do this extra allocation if we know we // will need to unmask? msg.raw = make([]byte, int(payloadOffset)+int(payloadLength)) @@ -136,6 +136,13 @@ func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) { // message payload. // https://github.com/kubernetes/apimachinery/commit/73d12d09c5be8703587b5127416eb83dc3b7e182#diff-291f96e8632d04d2d20f5fb00f6b323492670570d65434e8eac90c7a442d13bdR23-R36 if len(msgPayload) == 0 { + if !isInitialFragment { + // Continuation frame with zero payload. The stream ID is + // already known from the initial fragment, so this is not + // fatal, just unusual. + log.Infof("[unexpected] received a continuation fragment with no payload") + return true, nil + } return false, errors.New("[unexpected] received a message fragment with no stream ID") } diff --git a/k8s-operator/tsclient/client.go b/k8s-operator/tsclient/client.go new file mode 100644 index 0000000000000..e90e9d8495ac7 --- /dev/null +++ b/k8s-operator/tsclient/client.go @@ -0,0 +1,83 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package tsclient provides a mockable wrapper around the tailscale-client-go-v2 package for use by the Kubernetes +// operator. It also contains the Provider type used to manage multiple instances of tailscale clients for different +// tailnets. +package tsclient + +import ( + "context" + + "tailscale.com/client/tailscale/v2" +) + +type ( + // The Client interface describes types that interact with the Tailscale API. + Client interface { + // LoginURL should return the url of the Tailscale control plane. + LoginURL() string + // Devices should return a DeviceResource implementation used to interact with the devices API. + Devices() DeviceResource + // Keys should return a KeyResource implementation used to interact with the keys API. + Keys() KeyResource + // VIPServices should return a VIPServiceResource implementation used to interact with the VIP services API. + VIPServices() VIPServiceResource + } + + // The DeviceResource interface describes types that expose device related API endpoints. + DeviceResource interface { + // Delete should delete a device with a matching id. + Delete(ctx context.Context, id string) error + // List should return all devices based on the specified options. + List(ctx context.Context, opts ...tailscale.ListDevicesOptions) ([]tailscale.Device, error) + // Get should return the device with the matching identifier. + Get(ctx context.Context, id string) (*tailscale.Device, error) + } + + // The KeyResource interface describes types that expose key related API endpoints. + KeyResource interface { + // CreateAuthKey should create and return a new auth key used to authenticate a device. + CreateAuthKey(ctx context.Context, ckr tailscale.CreateKeyRequest) (*tailscale.Key, error) + // List should return keys created by the caller or all keys if the provided boolean is set to true. + List(ctx context.Context, all bool) ([]tailscale.Key, error) + } + + // The VIPServiceResource interface describes types that expose vip service related API endpoints. + VIPServiceResource interface { + // List should return all existing vip services within the tailnet. + List(ctx context.Context) ([]tailscale.VIPService, error) + // Delete should remove a named service from the tailnet. + Delete(ctx context.Context, name string) error + // Get should return the vip service associated with the given name. + Get(ctx context.Context, name string) (*tailscale.VIPService, error) + // CreateOrUpdate should update the provided vip service, creating it if it does not exist. + CreateOrUpdate(ctx context.Context, svc tailscale.VIPService) error + } + + clientWrapper struct { + loginURL string + client *tailscale.Client + } +) + +// Wrap converts a given tailscale.Client into a Client. +func Wrap(client *tailscale.Client) Client { + return &clientWrapper{client: client, loginURL: client.BaseURL.String()} +} + +func (c *clientWrapper) Devices() DeviceResource { + return c.client.Devices() +} + +func (c *clientWrapper) Keys() KeyResource { + return c.client.Keys() +} + +func (c *clientWrapper) VIPServices() VIPServiceResource { + return c.client.VIPServices() +} + +func (c *clientWrapper) LoginURL() string { + return c.loginURL +} diff --git a/k8s-operator/tsclient/provider.go b/k8s-operator/tsclient/provider.go new file mode 100644 index 0000000000000..613cebec3e497 --- /dev/null +++ b/k8s-operator/tsclient/provider.go @@ -0,0 +1,77 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsclient + +import ( + "errors" + "fmt" + "sync" +) + +type ( + // The Provider type is used to manage multiple Client implementations for different tailnets. + Provider struct { + defaultClient Client + mu sync.RWMutex + clients map[string]Client + readiness map[string]bool + } +) + +var ( + // ErrClientNotFound is the error given when calling Provider.For with a tailnet that has not yet been registered + // with the provider. + ErrClientNotFound = errors.New("client not found") + // ErrNotReady is the error given when calling Provider.For with a tailnet that has not yet been declared as + // ready to use by the operator. + ErrNotReady = errors.New("tailnet not ready") +) + +// NewProvider returns a new instance of the Provider type that uses the given Client implementation as the default +// client. This client will be given when calling Provider.For with a blank tailnet name. +func NewProvider(defaultClient Client) *Provider { + return &Provider{ + defaultClient: defaultClient, + clients: make(map[string]Client), + readiness: make(map[string]bool), + } +} + +// Add a Client implementation for a given tailnet. +func (p *Provider) Add(tailnet string, client Client, ready bool) { + p.mu.Lock() + defer p.mu.Unlock() + + p.clients[tailnet] = client + p.readiness[tailnet] = ready +} + +// Remove the Client implementation associated with the given tailnet. +func (p *Provider) Remove(tailnet string) { + p.mu.Lock() + defer p.mu.Unlock() + + delete(p.clients, tailnet) +} + +// For returns a Client implementation associated with the given tailnet. Returns ErrClientNotFound if the given +// tailnet does not exist. Use a blank tailnet name to obtain the default Client. +func (p *Provider) For(tailnet string) (Client, error) { + if tailnet == "" { + return p.defaultClient, nil + } + + p.mu.RLock() + defer p.mu.RUnlock() + + if client, ok := p.clients[tailnet]; ok { + if ready, _ := p.readiness[tailnet]; !ready { + return nil, fmt.Errorf("%w: %s", ErrNotReady, tailnet) + } + + return client, nil + } + + return nil, fmt.Errorf("%w: %s", ErrClientNotFound, tailnet) +} diff --git a/k8s-operator/utils.go b/k8s-operator/utils.go index 043a9d7b54c7a..d83d98e0cc8ca 100644 --- a/k8s-operator/utils.go +++ b/k8s-operator/utils.go @@ -7,6 +7,8 @@ package kube import ( + "crypto/sha256" + "encoding/hex" "fmt" "tailscale.com/tailcfg" @@ -50,3 +52,17 @@ func CapVerFromFileName(name string) (tailcfg.CapabilityVersion, error) { _, err := fmt.Sscanf(name, "cap-%d.hujson", &cap) return cap, err } + +// TruncateLabelValue truncates a Kubernetes label value to fit within the +// 63-character limit. If the value exceeds the limit, it is truncated and a +// short hash suffix is appended to preserve uniqueness. +func TruncateLabelValue(val string) string { + const maxLen = 63 + if len(val) <= maxLen { + return val + } + hash := sha256.Sum256([]byte(val)) + suffix := hex.EncodeToString(hash[:4]) // 8 hex chars + truncated := val[:maxLen-len(suffix)-1] + return truncated + "-" + suffix +} diff --git a/k8s-operator/utils_test.go b/k8s-operator/utils_test.go new file mode 100644 index 0000000000000..7a30df6b4e708 --- /dev/null +++ b/k8s-operator/utils_test.go @@ -0,0 +1,78 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package kube + +import ( + "strings" + "testing" +) + +func TestTruncateLabelValue(t *testing.T) { + tests := []struct { + name string + input string + want string // empty means expect input unchanged + }{ + { + name: "short-value-unchanged", + input: "my-service", + }, + { + name: "exactly-63-chars-unchanged", + input: strings.Repeat("a", 63), + }, + { + name: "64-chars-gets-truncated", + input: strings.Repeat("a", 64), + }, + { + name: "very-long-value-gets-truncated", + input: "tailscale-nginx-clickhouse-o11y-server-https-with-extra-long-suffix-that-exceeds-limit", + }, + { + name: "253-chars-max-k8s-resource-name", + input: strings.Repeat("x", 253), + }, + { + name: "empty-string-unchanged", + input: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := TruncateLabelValue(tt.input) + if len(got) > 63 { + t.Errorf("TruncateLabelValue(%q) = %q (len %d), exceeds 63 chars", tt.input, got, len(got)) + } + if len(tt.input) <= 63 && got != tt.input { + t.Errorf("TruncateLabelValue(%q) = %q, want unchanged input", tt.input, got) + } + if len(tt.input) > 63 && got == tt.input { + t.Errorf("TruncateLabelValue(%q) was not truncated", tt.input) + } + }) + } +} + +func TestTruncateLabelValueDeterministic(t *testing.T) { + input := strings.Repeat("a", 100) + first := TruncateLabelValue(input) + for i := 0; i < 10; i++ { + got := TruncateLabelValue(input) + if got != first { + t.Fatalf("non-deterministic: got %q, want %q", got, first) + } + } +} + +func TestTruncateLabelValueUniqueness(t *testing.T) { + // Two inputs sharing a long prefix but differing at the end should produce different outputs. + a := strings.Repeat("a", 100) + "-one" + b := strings.Repeat("a", 100) + "-two" + if TruncateLabelValue(a) == TruncateLabelValue(b) { + t.Errorf("collision: %q and %q produce the same truncated label", a, b) + } +} diff --git a/kube/authkey/authkey.go b/kube/authkey/authkey.go new file mode 100644 index 0000000000000..f544a0c81f010 --- /dev/null +++ b/kube/authkey/authkey.go @@ -0,0 +1,122 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// Package authkey provides shared logic for handling auth key reissue +// requests between tailnet clients (containerboot, k8s-proxy) and the +// operator. +// +// When a client fails to authenticate (expired key, single-use key already +// used), it signals the operator by setting a marker in its state Secret. +// The operator responds by deleting the old device and issuing a new auth +// key. The client watches for the new key and restarts to apply it. +package authkey + +import ( + "context" + "fmt" + "log" + "time" + + "tailscale.com/ipn" + "tailscale.com/ipn/conffile" + "tailscale.com/kube/kubeapi" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" +) + +const ( + TailscaleContainerFieldManager = "tailscale-container" +) + +// SetReissueAuthKey sets the reissue_authkey marker in the state Secret to +// signal to the operator that a new auth key is needed. The marker value is +// the auth key that failed to authenticate. +func SetReissueAuthKey(ctx context.Context, kc kubeclient.Client, stateSecretName string, authKey string, fieldManager string) error { + s := &kubeapi.Secret{ + Data: map[string][]byte{ + kubetypes.KeyReissueAuthkey: []byte(authKey), + }, + } + + log.Printf("Requesting a new auth key from operator") + return kc.StrategicMergePatchSecret(ctx, stateSecretName, s, fieldManager) +} + +// ClearReissueAuthKey removes the reissue_authkey marker from the state Secret +// to signal to the operator that we've successfully received the new key. +func ClearReissueAuthKey(ctx context.Context, kc kubeclient.Client, stateSecretName string, fieldManager string) error { + existing, err := kc.GetSecret(ctx, stateSecretName) + if err != nil { + return fmt.Errorf("error getting state secret: %w", err) + } + + s := &kubeapi.Secret{ + Data: map[string][]byte{ + kubetypes.KeyReissueAuthkey: nil, + kubetypes.KeyDeviceID: nil, + kubetypes.KeyDeviceFQDN: nil, + kubetypes.KeyDeviceIPs: nil, + string(ipn.MachineKeyStateKey): nil, + string(ipn.CurrentProfileStateKey): nil, + string(ipn.KnownProfilesStateKey): nil, + }, + } + + if profileKey := string(existing.Data[string(ipn.CurrentProfileStateKey)]); profileKey != "" { + s.Data[profileKey] = nil + } + + return kc.StrategicMergePatchSecret(ctx, stateSecretName, s, fieldManager) +} + +// WaitForAuthKeyReissue polls getAuthKey for a new auth key different from +// oldAuthKey, returning when one is found or maxWait expires. If notify is +// non-nil, it is used to wake the loop on config changes; otherwise it falls +// back to periodic polling. The clearFn callback is called when a new key is +// detected, to clear the reissue marker from the state Secret. +func WaitForAuthKeyReissue(ctx context.Context, oldAuthKey string, maxWait time.Duration, getAuthKey func() string, clearFn func(context.Context) error, notify <-chan struct{}) error { + log.Printf("Waiting for operator to provide new auth key (max wait: %v)", maxWait) + + ctx, cancel := context.WithTimeout(ctx, maxWait) + defer cancel() + + pollInterval := 5 * time.Second + pt := time.NewTicker(pollInterval) + defer pt.Stop() + + start := time.Now() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("timeout waiting for auth key reissue after %v", maxWait) + case <-pt.C: + case <-notify: + } + + newAuthKey := getAuthKey() + if newAuthKey != "" && newAuthKey != oldAuthKey { + log.Printf("New auth key received from operator after %v", time.Since(start).Round(time.Second)) + if err := clearFn(ctx); err != nil { + log.Printf("Warning: failed to clear reissue request: %v", err) + } + return nil + } + + if notify == nil { + log.Printf("Waiting for new auth key from operator (%v elapsed)", time.Since(start).Round(time.Second)) + } + } +} + +// AuthKeyFromConfig extracts the auth key from a tailscaled config file. +// Returns empty string if the file cannot be read or contains no auth key. +func AuthKeyFromConfig(path string) string { + if cfg, err := conffile.Load(path); err == nil && cfg.Parsed.AuthKey != nil { + return *cfg.Parsed.AuthKey + } + + return "" +} diff --git a/kube/authkey/authkey_test.go b/kube/authkey/authkey_test.go new file mode 100644 index 0000000000000..268bc46d6ac3e --- /dev/null +++ b/kube/authkey/authkey_test.go @@ -0,0 +1,124 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package authkey + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/ipn" + "tailscale.com/kube/kubeapi" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" +) + +func TestSetReissueAuthKey(t *testing.T) { + var patched map[string][]byte + kc := &kubeclient.FakeClient{ + StrategicMergePatchSecretImpl: func(ctx context.Context, name string, secret *kubeapi.Secret, _ string) error { + patched = secret.Data + return nil + }, + } + + err := SetReissueAuthKey(context.Background(), kc, "test-secret", "old-auth-key", TailscaleContainerFieldManager) + if err != nil { + t.Fatalf("SetReissueAuthKey() error = %v", err) + } + + want := map[string][]byte{ + kubetypes.KeyReissueAuthkey: []byte("old-auth-key"), + } + if diff := cmp.Diff(want, patched); diff != "" { + t.Errorf("SetReissueAuthKey() mismatch (-want +got):\n%s", diff) + } +} + +func TestClearReissueAuthKey(t *testing.T) { + var patched map[string][]byte + kc := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + return &kubeapi.Secret{ + Data: map[string][]byte{ + "_current-profile": []byte("profile-abc1"), + "profile-abc1": []byte("some-profile-data"), + "_machinekey": []byte("machine-key-data"), + }, + }, nil + }, + StrategicMergePatchSecretImpl: func(ctx context.Context, name string, secret *kubeapi.Secret, _ string) error { + patched = secret.Data + return nil + }, + } + + err := ClearReissueAuthKey(context.Background(), kc, "test-secret", TailscaleContainerFieldManager) + if err != nil { + t.Fatalf("ClearReissueAuthKey() error = %v", err) + } + + want := map[string][]byte{ + kubetypes.KeyReissueAuthkey: nil, + kubetypes.KeyDeviceID: nil, + kubetypes.KeyDeviceFQDN: nil, + kubetypes.KeyDeviceIPs: nil, + string(ipn.MachineKeyStateKey): nil, + string(ipn.CurrentProfileStateKey): nil, + string(ipn.KnownProfilesStateKey): nil, + "profile-abc1": nil, + } + if diff := cmp.Diff(want, patched); diff != "" { + t.Errorf("ClearReissueAuthKey() mismatch (-want +got):\n%s", diff) + } +} + +func TestAuthKeyFromConfig(t *testing.T) { + for name, tc := range map[string]struct { + configContent string + want string + }{ + "valid_config_with_authkey": { + configContent: `{"Version":"alpha0","AuthKey":"test-auth-key"}`, + want: "test-auth-key", + }, + "valid_config_without_authkey": { + configContent: `{"Version":"alpha0"}`, + want: "", + }, + "invalid_config": { + configContent: `not valid json`, + want: "", + }, + "empty_config": { + configContent: ``, + want: "", + }, + } { + t.Run(name, func(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + + if err := os.WriteFile(configPath, []byte(tc.configContent), 0600); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + got := AuthKeyFromConfig(configPath) + if got != tc.want { + t.Errorf("AuthKeyFromConfig() = %q, want %q", got, tc.want) + } + }) + } + + t.Run("nonexistent_file", func(t *testing.T) { + got := AuthKeyFromConfig("/nonexistent/path/config.json") + if got != "" { + t.Errorf("AuthKeyFromConfig() = %q, want empty string for nonexistent file", got) + } + }) +} diff --git a/kube/certs/certs.go b/kube/certs/certs.go index 4c8ac88b6b624..fd7c82a100dd7 100644 --- a/kube/certs/certs.go +++ b/kube/certs/certs.go @@ -171,8 +171,9 @@ func (cm *CertManager) runCertLoop(ctx context.Context, domain string) { } } -// waitForCertDomain ensures the requested domain is in the list of allowed -// domains before issuing the cert for the first time. +// domains before issuing the cert for the first time. It uses the IPN bus +// only as a wake-up trigger (Notify.SelfChange) and queries the current +// cert domains explicitly via [LocalClient.CertDomains]. func (cm *CertManager) waitForCertDomain(ctx context.Context, domain string) error { w, err := cm.lc.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) if err != nil { @@ -185,11 +186,14 @@ func (cm *CertManager) waitForCertDomain(ctx context.Context, domain string) err if err != nil { return err } - if n.NetMap == nil { + if n.SelfChange == nil { continue } - - if slices.Contains(n.NetMap.DNS.CertDomains, domain) { + domains, err := cm.lc.CertDomains(ctx) + if err != nil { + continue + } + if slices.Contains(domains, domain) { return nil } } diff --git a/kube/health/healthz.go b/kube/health/healthz.go index 53888922bb940..e9b459fc19e76 100644 --- a/kube/health/healthz.go +++ b/kube/health/healthz.go @@ -65,8 +65,8 @@ func (h *Healthz) MonitorHealth(ctx context.Context, lc *local.Client) error { return err } - if n.NetMap != nil { - h.Update(n.NetMap.SelfNode.Addresses().Len() != 0) + if self := n.SelfChange; self != nil { + h.Update(len(self.Addresses) != 0) } } } diff --git a/kube/k8s-proxy/conf/conf_test.go b/kube/k8s-proxy/conf/conf_test.go index 4034bf3cb7752..0c26b4242e92f 100644 --- a/kube/k8s-proxy/conf/conf_test.go +++ b/kube/k8s-proxy/conf/conf_test.go @@ -10,7 +10,6 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "tailscale.com/types/ptr" ) // Test that the config file can be at the root of the object, or in a versioned sub-object. @@ -23,17 +22,17 @@ func TestVersionedConfig(t *testing.T) { }{ "root_config_v1alpha1": { inputConfig: `{"version": "v1alpha1", "authKey": "abc123"}`, - expectedConfig: ConfigV1Alpha1{AuthKey: ptr.To("abc123")}, + expectedConfig: ConfigV1Alpha1{AuthKey: new("abc123")}, }, "backwards_compat_v1alpha1_config": { // Client doesn't know about v1beta1, so it should read in v1alpha1. inputConfig: `{"version": "v1beta1", "beta-key": "beta-value", "authKey": "def456", "v1alpha1": {"authKey": "abc123"}}`, - expectedConfig: ConfigV1Alpha1{AuthKey: ptr.To("abc123")}, + expectedConfig: ConfigV1Alpha1{AuthKey: new("abc123")}, }, "unknown_key_allowed": { // Adding new keys to the config doesn't require a version bump. inputConfig: `{"version": "v1alpha1", "unknown-key": "unknown-value", "authKey": "abc123"}`, - expectedConfig: ConfigV1Alpha1{AuthKey: ptr.To("abc123")}, + expectedConfig: ConfigV1Alpha1{AuthKey: new("abc123")}, }, "version_only_no_authkey": { inputConfig: `{"version": "v1alpha1"}`, diff --git a/kube/localclient/fake-client.go b/kube/localclient/fake-client.go index a244ce31a10c9..7ecada1134cd8 100644 --- a/kube/localclient/fake-client.go +++ b/kube/localclient/fake-client.go @@ -12,9 +12,10 @@ import ( type FakeLocalClient struct { FakeIPNBusWatcher - SetServeCalled bool - EditPrefsCalls []*ipn.MaskedPrefs - GetPrefsResult *ipn.Prefs + SetServeCalled bool + EditPrefsCalls []*ipn.MaskedPrefs + GetPrefsResult *ipn.Prefs + CertDomainsResult []string } func (m *FakeLocalClient) SetServeConfig(ctx context.Context, cfg *ipn.ServeConfig) error { @@ -45,6 +46,10 @@ func (f *FakeLocalClient) CertPair(ctx context.Context, domain string) ([]byte, return nil, nil, fmt.Errorf("CertPair not implemented") } +func (f *FakeLocalClient) CertDomains(ctx context.Context) ([]string, error) { + return f.CertDomainsResult, nil +} + type FakeIPNBusWatcher struct { NotifyChan chan ipn.Notify } diff --git a/kube/localclient/local-client.go b/kube/localclient/local-client.go index b8d40f4067c0e..f759568ba195f 100644 --- a/kube/localclient/local-client.go +++ b/kube/localclient/local-client.go @@ -19,6 +19,7 @@ type LocalClient interface { WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (IPNBusWatcher, error) SetServeConfig(context.Context, *ipn.ServeConfig) error EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn.Prefs, error) + CertDomains(ctx context.Context) ([]string, error) CertIssuer } @@ -57,3 +58,7 @@ func (lc *localClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) func (lc *localClient) CertPair(ctx context.Context, domain string) ([]byte, []byte, error) { return lc.lc.CertPair(ctx, domain) } + +func (lc *localClient) CertDomains(ctx context.Context) ([]string, error) { + return lc.lc.CertDomains(ctx) +} diff --git a/kube/state/state.go b/kube/state/state.go index ebedb2f725b3d..220eb439f80a4 100644 --- a/kube/state/state.go +++ b/kube/state/state.go @@ -30,19 +30,8 @@ const ( keyDeviceFQDN = ipn.StateKey(kubetypes.KeyDeviceFQDN) ) -// SetInitialKeys sets Pod UID and cap ver and clears tailnet device state -// keys to help stop the operator using stale tailnet device state. +// SetInitialKeys sets Pod UID and cap ver. func SetInitialKeys(store ipn.StateStore, podUID string) error { - // Clear device state keys first so the operator knows if the pod UID - // matches, the other values are definitely not stale. - for _, key := range []ipn.StateKey{keyDeviceID, keyDeviceFQDN, keyDeviceIPs} { - if _, err := store.ReadState(key); err == nil { - if err := store.WriteState(key, nil); err != nil { - return fmt.Errorf("error writing %q to state store: %w", key, err) - } - } - } - if err := store.WriteState(keyPodUID, []byte(podUID)); err != nil { return fmt.Errorf("error writing pod UID to state store: %w", err) } @@ -55,9 +44,9 @@ func SetInitialKeys(store ipn.StateStore, podUID string) error { // KeepKeysUpdated sets state store keys consistent with containerboot to // signal proxy readiness to the operator. It runs until its context is -// cancelled or it hits an error. The passed in next function is expected to be -// from a local.IPNBusWatcher that is at least subscribed to -// ipn.NotifyInitialNetMap. +// cancelled or it hits an error. It watches the IPN bus for SelfChange +// notifications (which fire whenever the self node changes) and reads +// the new self node directly from the notify. func KeepKeysUpdated(ctx context.Context, store ipn.StateStore, lc klc.LocalClient) error { w, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) if err != nil { @@ -74,25 +63,26 @@ func KeepKeysUpdated(ctx context.Context, store ipn.StateStore, lc klc.LocalClie } return err } - if n.NetMap == nil { + self := n.SelfChange + if self == nil { continue } - if deviceID := n.NetMap.SelfNode.StableID(); deephash.Update(¤tDeviceID, &deviceID) { + if deviceID := self.StableID; deephash.Update(¤tDeviceID, &deviceID) { if err := store.WriteState(keyDeviceID, []byte(deviceID)); err != nil { return fmt.Errorf("failed to store device ID in state: %w", err) } } - if fqdn := n.NetMap.SelfNode.Name(); deephash.Update(¤tDeviceFQDN, &fqdn) { + if fqdn := self.Name; deephash.Update(¤tDeviceFQDN, &fqdn) { if err := store.WriteState(keyDeviceFQDN, []byte(fqdn)); err != nil { return fmt.Errorf("failed to store device FQDN in state: %w", err) } } - if addrs := n.NetMap.SelfNode.Addresses(); deephash.Update(¤tDeviceIPs, &addrs) { + if addrs := self.Addresses; deephash.Update(¤tDeviceIPs, &addrs) { var deviceIPs []string - for _, addr := range addrs.AsSlice() { + for _, addr := range addrs { deviceIPs = append(deviceIPs, addr.Addr().String()) } deviceIPsValue, err := json.Marshal(deviceIPs) diff --git a/kube/state/state_test.go b/kube/state/state_test.go index 9b2ce69be5599..5c438377e814f 100644 --- a/kube/state/state_test.go +++ b/kube/state/state_test.go @@ -18,7 +18,6 @@ import ( klc "tailscale.com/kube/localclient" "tailscale.com/tailcfg" "tailscale.com/types/logger" - "tailscale.com/types/netmap" ) func TestSetInitialStateKeys(t *testing.T) { @@ -58,9 +57,9 @@ func TestSetInitialStateKeys(t *testing.T) { expected: map[ipn.StateKey][]byte{ keyPodUID: podUID, keyCapVer: expectedCapVer, - keyDeviceID: nil, - keyDeviceFQDN: nil, - keyDeviceIPs: nil, + keyDeviceID: []byte("existing-device-id"), + keyDeviceFQDN: []byte("existing-device-fqdn"), + keyDeviceIPs: []byte(`["1.2.3.4"]`), }, }, } { @@ -133,12 +132,10 @@ func TestKeepStateKeysUpdated(t *testing.T) { { name: "authed", notify: ipn.Notify{ - NetMap: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - StableID: "TESTCTRL00000001", - Name: "test-node.test.ts.net", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32"), netip.MustParsePrefix("fd7a:115c:a1e0:ab12:4843:cd96:0:1/128")}, - }).View(), + SelfChange: &tailcfg.Node{ + StableID: "TESTCTRL00000001", + Name: "test-node.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32"), netip.MustParsePrefix("fd7a:115c:a1e0:ab12:4843:cd96:0:1/128")}, }, }, expected: []string{ @@ -150,12 +147,10 @@ func TestKeepStateKeysUpdated(t *testing.T) { { name: "updated_fields", notify: ipn.Notify{ - NetMap: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - StableID: "TESTCTRL00000001", - Name: "updated.test.ts.net", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.250/32")}, - }).View(), + SelfChange: &tailcfg.Node{ + StableID: "TESTCTRL00000001", + Name: "updated.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.250/32")}, }, }, expected: []string{ diff --git a/licenses/android.md b/licenses/android.md index 15098f0752e79..07c97948e1da3 100644 --- a/licenses/android.md +++ b/licenses/android.md @@ -4,6 +4,7 @@ The following open source dependencies are used to build the [Tailscale Android Client][]. See also the dependencies in the [Tailscale CLI][]. [Tailscale Android Client]: https://github.com/tailscale/tailscale-android +[Tailscale CLI]: ./tailscale.md ## Go Packages @@ -21,31 +22,32 @@ Client][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/huin/goupnp](https://pkg.go.dev/github.com/huin/goupnp) ([BSD-2-Clause](https://github.com/huin/goupnp/blob/v1.3.0/LICENSE)) - [github.com/insomniacslk/dhcp](https://pkg.go.dev/github.com/insomniacslk/dhcp) ([BSD-3-Clause](https://github.com/insomniacslk/dhcp/blob/8c70d406f6d2/LICENSE)) - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.2/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.2/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.2/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.5/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.5/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.5/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) + - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/fbb4dce95f42/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - [github.com/pierrec/lz4/v4](https://pkg.go.dev/github.com/pierrec/lz4/v4) ([BSD-3-Clause](https://github.com/pierrec/lz4/blob/v4.1.25/LICENSE)) - [github.com/pires/go-proxyproto](https://pkg.go.dev/github.com/pires/go-proxyproto) ([Apache-2.0](https://github.com/pires/go-proxyproto/blob/v0.8.1/LICENSE)) - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/35a0c7bd7edc/LICENSE)) - [github.com/tailscale/tailscale-android/libtailscale](https://pkg.go.dev/github.com/tailscale/tailscale-android/libtailscale) ([BSD-3-Clause](https://github.com/tailscale/tailscale-android/blob/HEAD/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/1d0488a3d7da/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/e3ac4a0afb4e/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/d2acac8f3701/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.46.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.50.0:LICENSE)) - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/b7579e27:LICENSE)) - [golang.org/x/mobile](https://pkg.go.dev/golang.org/x/mobile) ([BSD-3-Clause](https://cs.opensource.google/go/x/mobile/+/81131f64:LICENSE)) - - [golang.org/x/mod/semver](https://pkg.go.dev/golang.org/x/mod/semver) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.30.0:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.48.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.19.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.40.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.38.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.32.0:LICENSE)) + - [golang.org/x/mod/semver](https://pkg.go.dev/golang.org/x/mod/semver) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.35.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.53.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.20.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.43.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.42.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.36.0:LICENSE)) - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.12.0:LICENSE)) - - [golang.org/x/tools](https://pkg.go.dev/golang.org/x/tools) ([BSD-3-Clause](https://cs.opensource.google/go/x/tools/+/v0.39.0:LICENSE)) - - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/9414b50a5633/LICENSE)) + - [golang.org/x/tools](https://pkg.go.dev/golang.org/x/tools) ([BSD-3-Clause](https://cs.opensource.google/go/x/tools/+/v0.44.0:LICENSE)) + - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/573d5e7127a8/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) diff --git a/licenses/apple.md b/licenses/apple.md index f7989fe250a63..a7bc6af8cb94b 100644 --- a/licenses/apple.md +++ b/licenses/apple.md @@ -12,14 +12,14 @@ See also the dependencies in the [Tailscale CLI][]. - [filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519) ([BSD-3-Clause](https://github.com/FiloSottile/edwards25519/blob/v1.2.0/LICENSE)) - - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.41.0/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.41.5/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.32.5/config/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.19.5/credentials/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.18.16/feature/ec2/imds/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.4.16/internal/configsources/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.7.16/internal/endpoints/v2/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.4/internal/ini/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.41.0/internal/sync/singleflight/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.41.5/internal/sync/singleflight/LICENSE)) - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.13.4/service/internal/accept-encoding/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.13.16/service/internal/presigned-url/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/signin](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/signin) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/signin/v1.0.4/service/signin/LICENSE.txt)) @@ -27,8 +27,8 @@ See also the dependencies in the [Tailscale CLI][]. - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.30.7/service/sso/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.35.12/service/ssooidc/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.41.5/service/sts/LICENSE.txt)) - - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.24.0/LICENSE)) - - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.24.0/internal/sync/singleflight/LICENSE)) + - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.24.2/LICENSE)) + - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.24.2/internal/sync/singleflight/LICENSE)) - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) - [github.com/creachadair/msync/trigger](https://pkg.go.dev/github.com/creachadair/msync/trigger) ([BSD-3-Clause](https://github.com/creachadair/msync/blob/v0.8.1/LICENSE)) - [github.com/digitalocean/go-smbios/smbios](https://pkg.go.dev/github.com/digitalocean/go-smbios/smbios) ([Apache-2.0](https://github.com/digitalocean/go-smbios/blob/390a4f403a8e/LICENSE.md)) @@ -48,9 +48,9 @@ See also the dependencies in the [Tailscale CLI][]. - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - [github.com/jsimonetti/rtnetlink](https://pkg.go.dev/github.com/jsimonetti/rtnetlink) ([MIT](https://github.com/jsimonetti/rtnetlink/blob/v1.4.1/LICENSE.md)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.2/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.2/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.2/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.5/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.5/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.5/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - [github.com/mdlayher/genetlink](https://pkg.go.dev/github.com/mdlayher/genetlink) ([MIT](https://github.com/mdlayher/genetlink/blob/v1.3.2/LICENSE.md)) - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/fbb4dce95f42/LICENSE.md)) @@ -63,21 +63,21 @@ See also the dependencies in the [Tailscale CLI][]. - [github.com/safchain/ethtool](https://pkg.go.dev/github.com/safchain/ethtool) ([Apache-2.0](https://github.com/safchain/ethtool/blob/v0.3.0/LICENSE)) - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/35a0c7bd7edc/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/1d0488a3d7da/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/e3ac4a0afb4e/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/d2acac8f3701/LICENSE)) - [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.5/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.47.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/a4bb9ffd:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.49.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.19.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.40.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.39.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.33.0:LICENSE)) - - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.12.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.50.0:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/3dfff04d:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.53.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.20.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.43.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.42.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.36.0:LICENSE)) + - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.15.0:LICENSE)) - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/573d5e7127a8/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) diff --git a/licenses/tailscale.md b/licenses/tailscale.md index 5050b38db2178..01fdec26f327c 100644 --- a/licenses/tailscale.md +++ b/licenses/tailscale.md @@ -58,13 +58,14 @@ Some packages may only be included on certain architectures or operating systems - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - [github.com/kballard/go-shellquote](https://pkg.go.dev/github.com/kballard/go-shellquote) ([MIT](https://github.com/kballard/go-shellquote/blob/95032a82bc51/LICENSE)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.2/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.2/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.2/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.5/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.5/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.5/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - [github.com/kr/fs](https://pkg.go.dev/github.com/kr/fs) ([BSD-3-Clause](https://github.com/kr/fs/blob/v0.1.0/LICENSE)) - [github.com/mattn/go-colorable](https://pkg.go.dev/github.com/mattn/go-colorable) ([MIT](https://github.com/mattn/go-colorable/blob/v0.1.13/LICENSE)) - [github.com/mattn/go-isatty](https://pkg.go.dev/github.com/mattn/go-isatty) ([MIT](https://github.com/mattn/go-isatty/blob/v0.0.20/LICENSE)) + - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/fbb4dce95f42/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - [github.com/mitchellh/go-ps](https://pkg.go.dev/github.com/mitchellh/go-ps) ([MIT](https://github.com/mitchellh/go-ps/blob/v1.0.0/LICENSE.md)) - [github.com/peterbourgon/ff/v3](https://pkg.go.dev/github.com/peterbourgon/ff/v3) ([Apache-2.0](https://github.com/peterbourgon/ff/blob/v3.4.0/LICENSE)) @@ -73,11 +74,12 @@ Some packages may only be included on certain architectures or operating systems - [github.com/pkg/sftp](https://pkg.go.dev/github.com/pkg/sftp) ([BSD-2-Clause](https://github.com/pkg/sftp/blob/v1.13.6/LICENSE)) - [github.com/prometheus-community/pro-bing](https://pkg.go.dev/github.com/prometheus-community/pro-bing) ([MIT](https://github.com/prometheus-community/pro-bing/blob/v0.4.0/LICENSE)) - [github.com/skip2/go-qrcode](https://pkg.go.dev/github.com/skip2/go-qrcode) ([MIT](https://github.com/skip2/go-qrcode/blob/da1b6568686e/LICENSE)) - - [github.com/tailscale/certstore](https://pkg.go.dev/github.com/tailscale/certstore) ([MIT](https://github.com/tailscale/certstore/blob/d3fa0460f47e/LICENSE.md)) + - [github.com/tailscale/certstore](https://pkg.go.dev/github.com/tailscale/certstore) ([MIT](https://github.com/tailscale/certstore/blob/3638fb84b77d/LICENSE.md)) + - [github.com/tailscale/gliderssh](https://pkg.go.dev/github.com/tailscale/gliderssh) ([BSD-3-Clause](https://github.com/tailscale/gliderssh/blob/c1389c70ff89/LICENSE)) - [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE)) - [github.com/tailscale/web-client-prebuilt](https://pkg.go.dev/github.com/tailscale/web-client-prebuilt) ([BSD-3-Clause](https://github.com/tailscale/web-client-prebuilt/blob/d4cd19a26976/LICENSE)) - [github.com/tailscale/wf](https://pkg.go.dev/github.com/tailscale/wf) ([BSD-3-Clause](https://github.com/tailscale/wf/blob/6fbb0a674ee6/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/1d0488a3d7da/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/e3ac4a0afb4e/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - [github.com/toqueteos/webbrowser](https://pkg.go.dev/github.com/toqueteos/webbrowser) ([MIT](https://github.com/toqueteos/webbrowser/blob/v1.2.0/LICENSE.md)) - [github.com/u-root/u-root/pkg/termios](https://pkg.go.dev/github.com/u-root/u-root/pkg/termios) ([BSD-3-Clause](https://github.com/u-root/u-root/blob/v0.14.0/LICENSE)) @@ -86,15 +88,15 @@ Some packages may only be included on certain architectures or operating systems - [go.yaml.in/yaml/v2](https://pkg.go.dev/go.yaml.in/yaml/v2) ([Apache-2.0](https://github.com/yaml/go-yaml/blob/v2.4.2/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.46.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.50.0:LICENSE)) - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/b7579e27:LICENSE)) - [golang.org/x/image](https://pkg.go.dev/golang.org/x/image) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.27.0:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.48.0:LICENSE)) - - [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.33.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.19.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.40.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.38.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.32.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.53.0:LICENSE)) + - [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.36.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.20.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.43.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.42.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.36.0:LICENSE)) - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.12.0:LICENSE)) - [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2)) - [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3)) @@ -102,5 +104,4 @@ Some packages may only be included on certain architectures or operating systems - [k8s.io/client-go/util/homedir](https://pkg.go.dev/k8s.io/client-go/util/homedir) ([Apache-2.0](https://github.com/kubernetes/client-go/blob/v0.34.0/LICENSE)) - [sigs.k8s.io/yaml](https://pkg.go.dev/sigs.k8s.io/yaml) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.6.0/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) - - [tailscale.com/tempfork/gliderlabs/ssh](https://pkg.go.dev/tailscale.com/tempfork/gliderlabs/ssh) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/tempfork/gliderlabs/ssh/LICENSE)) - [tailscale.com/tempfork/spf13/cobra](https://pkg.go.dev/tailscale.com/tempfork/spf13/cobra) ([Apache-2.0](https://github.com/tailscale/tailscale/blob/HEAD/tempfork/spf13/cobra/LICENSE.txt)) diff --git a/licenses/windows.md b/licenses/windows.md index e8bcc932f332f..33c142550ddcd 100644 --- a/licenses/windows.md +++ b/licenses/windows.md @@ -28,9 +28,9 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/hdevalence/ed25519consensus](https://pkg.go.dev/github.com/hdevalence/ed25519consensus) ([BSD-3-Clause](https://github.com/hdevalence/ed25519consensus/blob/v0.2.0/LICENSE)) - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - [github.com/jsimonetti/rtnetlink](https://pkg.go.dev/github.com/jsimonetti/rtnetlink) ([MIT](https://github.com/jsimonetti/rtnetlink/blob/v1.4.1/LICENSE.md)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.2/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.2/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.2/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.5/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.5/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.5/zstd/internal/xxhash/LICENSE.txt)) - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/fbb4dce95f42/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - [github.com/mitchellh/go-ps](https://pkg.go.dev/github.com/mitchellh/go-ps) ([MIT](https://github.com/mitchellh/go-ps/blob/v1.0.0/LICENSE.md)) @@ -42,7 +42,7 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/prometheus/common](https://pkg.go.dev/github.com/prometheus/common) ([Apache-2.0](https://github.com/prometheus/common/blob/v0.67.5/LICENSE)) - [github.com/skip2/go-qrcode](https://pkg.go.dev/github.com/skip2/go-qrcode) ([MIT](https://github.com/skip2/go-qrcode/blob/da1b6568686e/LICENSE)) - [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE)) - - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/992244df8c5a/LICENSE)) + - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/ecc657c15afd/LICENSE)) - [github.com/tailscale/walk](https://pkg.go.dev/github.com/tailscale/walk) ([BSD-3-Clause](https://github.com/tailscale/walk/blob/963e260a8227/LICENSE)) - [github.com/tailscale/win](https://pkg.go.dev/github.com/tailscale/win) ([BSD-3-Clause](https://github.com/tailscale/win/blob/f4da2b8ee071/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) @@ -51,14 +51,14 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [go.yaml.in/yaml/v2](https://pkg.go.dev/go.yaml.in/yaml/v2) ([Apache-2.0](https://github.com/yaml/go-yaml/blob/v2.4.3/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.47.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/a4bb9ffd:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.50.0:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/3dfff04d:LICENSE)) - [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.27.0:LICENSE)) - - [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.32.0:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.49.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.19.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.40.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.39.0:LICENSE)) + - [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.35.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.53.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.20.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.43.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.42.0:LICENSE)) - [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2)) - [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3)) - [google.golang.org/protobuf](https://pkg.go.dev/google.golang.org/protobuf) ([BSD-3-Clause](https://github.com/protocolbuffers/protobuf-go/blob/v1.36.11/LICENSE)) diff --git a/log/sockstatlog/logger_test.go b/log/sockstatlog/logger_test.go index 66228731e368e..9176cfe53ab57 100644 --- a/log/sockstatlog/logger_test.go +++ b/log/sockstatlog/logger_test.go @@ -39,19 +39,19 @@ func TestDelta(t *testing.T) { wantStats map[sockstats.Label]deltaStat }{ { - name: "nil a stat", + name: "nil-a-stat", a: nil, b: &sockstats.SockStats{}, wantStats: nil, }, { - name: "nil b stat", + name: "nil-b-stat", a: &sockstats.SockStats{}, b: nil, wantStats: nil, }, { - name: "no change", + name: "no-change", a: &sockstats.SockStats{ Stats: map[sockstats.Label]sockstats.SockStat{ sockstats.LabelDERPHTTPClient: { @@ -69,7 +69,7 @@ func TestDelta(t *testing.T) { wantStats: nil, }, { - name: "tx after empty stat", + name: "tx-after-empty-stat", a: &sockstats.SockStats{}, b: &sockstats.SockStats{ Stats: map[sockstats.Label]sockstats.SockStat{ @@ -83,7 +83,7 @@ func TestDelta(t *testing.T) { }, }, { - name: "rx after non-empty stat", + name: "rx-after-non-empty-stat", a: &sockstats.SockStats{ Stats: map[sockstats.Label]sockstats.SockStat{ sockstats.LabelDERPHTTPClient: { diff --git a/logtail/config.go b/logtail/config.go index c504047a3f2bf..0ee5999059fb5 100644 --- a/logtail/config.go +++ b/logtail/config.go @@ -64,4 +64,12 @@ type Config struct { // being included in the logs. The sequence number is incremented for each // log message sent, but is not persisted across process restarts. IncludeProcSequence bool + + // Disabled, if true, causes the returned [Logger] to start in the + // disabled state, dropping entries without buffering or uploading + // (equivalent to calling [Logger.SetEnabled] with false immediately). + // It applies before the internal startup banner is written, so no + // log entries are emitted until [Logger.SetEnabled] is called with + // true. The process-wide [Disable] kill switch still takes precedence. + Disabled bool } diff --git a/logtail/logtail.go b/logtail/logtail.go index ef296568da957..a45f1bfe9da8b 100644 --- a/logtail/logtail.go +++ b/logtail/logtail.go @@ -132,6 +132,7 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger { } logger.SetSockstatsLabel(sockstats.LabelLogtailLogger) logger.compressLogs = cfg.CompressLogs + logger.disabled.Store(cfg.Disabled) ctx, cancel := context.WithCancel(context.Background()) logger.uploadCancel = cancel @@ -172,6 +173,11 @@ type Logger struct { procID uint32 includeProcSequence bool + // disabled, when true, causes this logger to drop incoming log entries + // without buffering or uploading. It is independent of the process-wide + // Disable kill switch, which takes precedence. Toggled by SetEnabled. + disabled atomic.Bool + writeLock sync.Mutex // guards procSequence, flushTimer, buffer.Write calls procSequence uint64 flushTimer tstime.TimerController // used when flushDelay is >0 @@ -594,6 +600,15 @@ func Disable() { logtailDisabled.Store(true) } +// SetEnabled enables or disables log uploading by lg. When disabled, log +// entries passed to lg are dropped rather than buffered or uploaded; already +// buffered entries may still drain. The process-wide [Disable] kill switch +// takes precedence: if Disable has been called, SetEnabled(true) does not +// re-enable uploads. +func (lg *Logger) SetEnabled(enabled bool) { + lg.disabled.Store(!enabled) +} + var debugWakesAndUploads = envknob.RegisterBool("TS_DEBUG_LOGTAIL_WAKES") // tryDrainWake tries to send to lg.drainWake, to cause an uploading wakeup. @@ -613,7 +628,7 @@ func (lg *Logger) tryDrainWake() { func (lg *Logger) sendLocked(jsonBlob []byte) (int, error) { tapSend(jsonBlob) - if logtailDisabled.Load() { + if logtailDisabled.Load() || lg.disabled.Load() { return len(jsonBlob), nil } @@ -902,8 +917,8 @@ func parseAndRemoveLogLevel(buf []byte) (level int, cleanBuf []byte) { if bytes.Contains(buf, v2) { return 2, bytes.ReplaceAll(buf, v2, nil) } - if i := bytes.Index(buf, vJSON); i != -1 { - rest := buf[i+len(vJSON):] + if _, after, ok := bytes.Cut(buf, vJSON); ok { + rest := after if len(rest) >= 2 { v := rest[0] if v >= '0' && v <= '9' { diff --git a/logtail/logtail_omit.go b/logtail/logtail_omit.go index 21f18c980cce4..98f1c6a0e5d6b 100644 --- a/logtail/logtail_omit.go +++ b/logtail/logtail_omit.go @@ -20,6 +20,8 @@ type Buffer any func Disable() {} +func (*Logger) SetEnabled(enabled bool) {} + func NewLogger(cfg Config, logf tslogger.Logf) *Logger { return &Logger{} } diff --git a/logtail/logtail_test.go b/logtail/logtail_test.go index 67250ae0db03f..8273097c3aebf 100644 --- a/logtail/logtail_test.go +++ b/logtail/logtail_test.go @@ -7,34 +7,51 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" + "net" "net/http" - "net/http/httptest" + "os" "strings" + "sync" "testing" + "testing/synctest" "time" "github.com/go-json-experiment/json/jsontext" + "tailscale.com/net/memnet" "tailscale.com/tstest" "tailscale.com/tstime" "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/must" ) -func TestFastShutdown(t *testing.T) { +// TestMain installs a safety net that refuses non-localhost dials for any +// test in this package. Config.BaseURL defaults to https://log.tailscale.com +// and Config.HTTPC defaults to http.DefaultClient, so a test that forgets to +// override either can otherwise silently hit the real logtail server. +// Tests that need an HTTP server should use memnet (see newTestLogtailServer). +func TestMain(m *testing.M) { + tr := http.DefaultTransport.(*http.Transport) + orig := tr.DialContext + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err == nil && (host == "127.0.0.1" || host == "::1" || host == "localhost") { + return orig(ctx, network, addr) + } + return nil, fmt.Errorf("logtail tests: refusing to dial non-localhost address %q; use memnet or a custom Config.HTTPC", addr) + } + os.Exit(m.Run()) +} + +func TestFastShutdown(t *testing.T) { synctest.Test(t, synctestFastShutdown) } + +func synctestFastShutdown(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - testServ := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) {})) - defer testServ.Close() - - logger := NewLogger(Config{ - BaseURL: testServ.URL, - Bus: eventbustest.NewBus(t), - }, t.Logf) - err := logger.Shutdown(ctx) - if err != nil { + _, logger := newTestLogtailServer(t) + if err := logger.Shutdown(ctx); err != nil { t.Error(err) } } @@ -43,67 +60,78 @@ func TestFastShutdown(t *testing.T) { const logLines = 3 type LogtailTestServer struct { - srv *httptest.Server // Log server uploaded chan []byte } -func NewLogtailTestHarness(t *testing.T) (*LogtailTestServer, *Logger) { - ts := LogtailTestServer{} - - // max channel backlog = 1 "started" + #logLines x "log line" + 1 "closed" - ts.uploaded = make(chan []byte, 2+logLines) +// newTestLogtailServer wires up an in-memory HTTP server (via memnet) and a +// *Logger whose HTTPC dials it. Lives inside the caller's synctest bubble so +// the default FlushDelay and any other fake timers advance automatically. +func newTestLogtailServer(t *testing.T) (*LogtailTestServer, *Logger) { + ts := &LogtailTestServer{ + // max channel backlog = 1 "started" + #logLines x "log line" + 1 "closed" + uploaded: make(chan []byte, 2+logLines), + } - ts.srv = httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - t.Error("failed to read HTTP request") - } - ts.uploaded <- body - })) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error("failed to read HTTP request") + } + ts.uploaded <- body + }) - t.Cleanup(ts.srv.Close) + ln := memnet.Listen("logtail-test:0") + httpsrv := &http.Server{Handler: handler} + go httpsrv.Serve(ln) + t.Cleanup(func() { + httpsrv.Close() + ln.Close() + }) logger := NewLogger(Config{ - BaseURL: ts.srv.URL, + BaseURL: "http://" + ln.Addr().String(), Bus: eventbustest.NewBus(t), + HTTPC: &http.Client{ + Transport: &http.Transport{DialContext: ln.Dial}, + }, }, t.Logf) - // There is always an initial "logtail started" message + // There is always an initial "logtail started" message. body := <-ts.uploaded if !strings.Contains(string(body), "started") { t.Errorf("unknown start logging statement: %q", string(body)) } - - return &ts, logger + return ts, logger } -func TestDrainPendingMessages(t *testing.T) { - ts, logger := NewLogtailTestHarness(t) +func TestDrainPendingMessages(t *testing.T) { synctest.Test(t, synctestDrainPendingMessages) } + +func synctestDrainPendingMessages(t *testing.T) { + ts, logger := newTestLogtailServer(t) for range logLines { logger.Write([]byte("log line")) } - // all of the "log line" messages usually arrive at once, but poll if needed. - body := "" + // All the "log line" messages usually arrive at once, but poll if needed. + var body strings.Builder for i := 0; i <= logLines; i++ { - body += string(<-ts.uploaded) - count := strings.Count(body, "log line") + body.WriteString(string(<-ts.uploaded)) + count := strings.Count(body.String(), "log line") if count == logLines { break } - // if we never find count == logLines, the test will eventually time out. } - err := logger.Shutdown(context.Background()) - if err != nil { + if err := logger.Shutdown(context.Background()); err != nil { t.Error(err) } } -func TestEncodeAndUploadMessages(t *testing.T) { - ts, logger := NewLogtailTestHarness(t) +func TestEncodeAndUploadMessages(t *testing.T) { synctest.Test(t, synctestEncodeAndUploadMessages) } + +func synctestEncodeAndUploadMessages(t *testing.T) { + ts, logger := newTestLogtailServer(t) tests := []struct { name string @@ -144,8 +172,7 @@ func TestEncodeAndUploadMessages(t *testing.T) { } } - err := logger.Shutdown(context.Background()) - if err != nil { + if err := logger.Shutdown(context.Background()); err != nil { t.Error(err) } } @@ -321,6 +348,90 @@ func TestLoggerWriteResult(t *testing.T) { } } +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +func TestNewLoggerDisabled(t *testing.T) { synctest.Test(t, synctestNewLoggerDisabled) } + +func synctestNewLoggerDisabled(t *testing.T) { + // When Config.Disabled is true, NewLogger must not emit the usual + // "logtail started" banner: the logger should start in the disabled + // state before the internal startup write, so nothing ever lands + // in the buffer for the upload goroutine to drain. + buf := NewMemoryBuffer(100) + + // Any HTTP attempt indicates the banner leaked into the buffer and + // the upload goroutine tried to ship it. Report it once (so the + // retry spin doesn't drown the log), then block on the request + // context so synctest.Wait sees a durable block and Shutdown's + // uploadCancel can unblock us cleanly. + var once sync.Once + httpc := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + once.Do(func() { + t.Errorf("unexpected HTTP request while Disabled=true: %s", r.URL) + }) + <-r.Context().Done() + return nil, r.Context().Err() + }), + } + + logger := NewLogger(Config{ + BaseURL: "http://logtail.test.invalid", + HTTPC: httpc, + Bus: eventbustest.NewBus(t), + Buffer: buf, + Disabled: true, + }, t.Logf) + defer func() { + // Pass an already-cancelled context so Shutdown invokes + // uploadCancel immediately; otherwise on the regression path + // (Disabled=false) the upload goroutine stays in its retry + // loop and synctest.Test never returns. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + logger.Shutdown(ctx) + }() + + synctest.Wait() + + if back, _ := buf.TryReadLine(); len(back) != 0 { + t.Errorf("Disabled logger buffered a startup entry: %q", back) + } +} + +func TestLoggerSetEnabled(t *testing.T) { + buf := NewMemoryBuffer(100) + lg := &Logger{ + clock: tstest.NewClock(tstest.ClockOpts{Start: time.Unix(123, 0)}), + buffer: buf, + } + + if _, err := lg.Write([]byte("enabled1")); err != nil { + t.Fatal(err) + } + if back, _ := buf.TryReadLine(); !strings.Contains(string(back), "enabled1") { + t.Fatalf("initial write not buffered; got %q", back) + } + + lg.SetEnabled(false) + if _, err := lg.Write([]byte("disabled")); err != nil { + t.Fatal(err) + } + if back, _ := buf.TryReadLine(); len(back) != 0 { + t.Errorf("write while disabled leaked into buffer: %q", back) + } + + lg.SetEnabled(true) + if _, err := lg.Write([]byte("enabled2")); err != nil { + t.Fatal(err) + } + if back, _ := buf.TryReadLine(); !strings.Contains(string(back), "enabled2") { + t.Errorf("write after re-enable not buffered; got %q", back) + } +} + func TestAppendMetadata(t *testing.T) { var lg Logger lg.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) diff --git a/maths/ewma.go b/maths/ewma.go deleted file mode 100644 index 1946081cf6d08..0000000000000 --- a/maths/ewma.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package maths contains additional mathematical functions or structures not -// found in the standard library. -package maths - -import ( - "math" - "time" -) - -// EWMA is an exponentially weighted moving average supporting updates at -// irregular intervals with at most nanosecond resolution. -// The zero value will compute a half-life of 1 second. -// It is not safe for concurrent use. -// TODO(raggi): de-duplicate with tstime/rate.Value, which has a more complex -// and synchronized interface and does not provide direct access to the stable -// value. -type EWMA struct { - value float64 // current value of the average - lastTime int64 // time of last update in unix nanos - halfLife float64 // half-life in seconds -} - -// NewEWMA creates a new EWMA with the specified half-life. If halfLifeSeconds -// is 0, it defaults to 1. -func NewEWMA(halfLifeSeconds float64) *EWMA { - return &EWMA{ - halfLife: halfLifeSeconds, - } -} - -// Update adds a new sample to the average. If t is zero or precedes the last -// update, the update is ignored. -func (e *EWMA) Update(value float64, t time.Time) { - if t.IsZero() { - return - } - hl := e.halfLife - if hl == 0 { - hl = 1 - } - tn := t.UnixNano() - if e.lastTime == 0 { - e.value = value - e.lastTime = tn - return - } - - dt := (time.Duration(tn-e.lastTime) * time.Nanosecond).Seconds() - if dt < 0 { - // drop out of order updates - return - } - - // decay = 2^(-dt/halfLife) - decay := math.Exp2(-dt / hl) - e.value = e.value*decay + value*(1-decay) - e.lastTime = tn -} - -// Get returns the current value of the average -func (e *EWMA) Get() float64 { - return e.value -} - -// Reset clears the EWMA to its initial state -func (e *EWMA) Reset() { - e.value = 0 - e.lastTime = 0 -} diff --git a/maths/ewma_test.go b/maths/ewma_test.go deleted file mode 100644 index 9fddf34e17193..0000000000000 --- a/maths/ewma_test.go +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package maths - -import ( - "slices" - "testing" - "time" -) - -// some real world latency samples. -var ( - latencyHistory1 = []int{ - 14, 12, 15, 6, 19, 12, 13, 13, 13, 16, 17, 11, 17, 11, 14, 15, 14, 15, - 16, 16, 17, 14, 12, 16, 18, 14, 14, 11, 15, 15, 25, 11, 15, 14, 12, 15, - 13, 12, 13, 15, 11, 13, 15, 14, 14, 15, 12, 15, 18, 12, 15, 22, 12, 13, - 10, 14, 16, 15, 16, 11, 14, 17, 18, 20, 16, 11, 16, 14, 5, 15, 17, 12, - 15, 11, 15, 20, 12, 17, 12, 17, 15, 12, 12, 11, 14, 15, 11, 20, 14, 13, - 11, 12, 13, 13, 11, 13, 11, 15, 13, 13, 14, 12, 11, 12, 12, 14, 11, 13, - 12, 12, 12, 19, 14, 13, 13, 14, 11, 12, 10, 11, 15, 12, 14, 11, 11, 14, - 14, 12, 12, 11, 14, 12, 11, 12, 14, 11, 12, 15, 12, 14, 12, 12, 21, 16, - 21, 12, 16, 9, 11, 16, 14, 13, 14, 12, 13, 16, - } - latencyHistory2 = []int{ - 18, 20, 21, 21, 20, 23, 18, 18, 20, 21, 20, 19, 22, 18, 20, 20, 19, 21, - 21, 22, 22, 19, 18, 22, 22, 19, 20, 17, 16, 11, 25, 16, 18, 21, 17, 22, - 19, 18, 22, 21, 20, 18, 22, 17, 17, 20, 19, 10, 19, 16, 19, 25, 17, 18, - 15, 20, 21, 20, 23, 22, 22, 22, 19, 22, 22, 17, 22, 20, 20, 19, 21, 22, - 20, 19, 17, 22, 16, 16, 20, 22, 17, 19, 21, 16, 20, 22, 19, 21, 20, 19, - 13, 14, 23, 19, 16, 10, 19, 15, 15, 17, 16, 18, 14, 16, 18, 22, 20, 18, - 18, 21, 15, 19, 18, 19, 18, 20, 17, 19, 21, 19, 20, 19, 20, 20, 17, 14, - 17, 17, 18, 21, 20, 18, 18, 17, 16, 17, 17, 20, 22, 19, 20, 21, 21, 20, - 21, 24, 20, 18, 12, 17, 18, 17, 19, 19, 19, - } -) - -func TestEWMALatencyHistory(t *testing.T) { - type result struct { - t time.Time - v float64 - s int - } - - for _, latencyHistory := range [][]int{latencyHistory1, latencyHistory2} { - startTime := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) - halfLife := 30.0 - - ewma := NewEWMA(halfLife) - - var results []result - sum := 0.0 - for i, latency := range latencyHistory { - t := startTime.Add(time.Duration(i) * time.Second) - ewma.Update(float64(latency), t) - sum += float64(latency) - - results = append(results, result{t, ewma.Get(), latency}) - } - mean := sum / float64(len(latencyHistory)) - min := float64(slices.Min(latencyHistory)) - max := float64(slices.Max(latencyHistory)) - - t.Logf("EWMA Latency History (half-life: %.1f seconds):", halfLife) - t.Logf("Mean latency: %.2f ms", mean) - t.Logf("Range: [%.1f, %.1f]", min, max) - - t.Log("Samples: ") - sparkline := []rune("▁▂▃▄▅▆▇█") - var sampleLine []rune - for _, r := range results { - idx := int(((float64(r.s) - min) / (max - min)) * float64(len(sparkline)-1)) - if idx >= len(sparkline) { - idx = len(sparkline) - 1 - } - sampleLine = append(sampleLine, sparkline[idx]) - } - t.Log(string(sampleLine)) - - t.Log("EWMA: ") - var ewmaLine []rune - for _, r := range results { - idx := int(((r.v - min) / (max - min)) * float64(len(sparkline)-1)) - if idx >= len(sparkline) { - idx = len(sparkline) - 1 - } - ewmaLine = append(ewmaLine, sparkline[idx]) - } - t.Log(string(ewmaLine)) - t.Log("") - - t.Logf("Time | Sample | Value | Value - Sample") - t.Logf("") - - for _, result := range results { - t.Logf("%10s | % 6d | % 5.2f | % 5.2f", result.t.Format("15:04:05"), result.s, result.v, result.v-float64(result.s)) - } - - // check that all results are greater than the min, and less than the max of the input, - // and they're all close to the mean. - for _, result := range results { - if result.v < float64(min) || result.v > float64(max) { - t.Errorf("result %f out of range [%f, %f]", result.v, min, max) - } - - if result.v < mean*0.9 || result.v > mean*1.1 { - t.Errorf("result %f not close to mean %f", result.v, mean) - } - } - } -} - -func TestHalfLife(t *testing.T) { - start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) - - ewma := NewEWMA(30.0) - ewma.Update(10, start) - ewma.Update(0, start.Add(30*time.Second)) - - if ewma.Get() != 5 { - t.Errorf("expected 5, got %f", ewma.Get()) - } - - ewma.Update(10, start.Add(60*time.Second)) - if ewma.Get() != 7.5 { - t.Errorf("expected 7.5, got %f", ewma.Get()) - } - - ewma.Update(10, start.Add(90*time.Second)) - if ewma.Get() != 8.75 { - t.Errorf("expected 8.75, got %f", ewma.Get()) - } -} - -func TestZeroValue(t *testing.T) { - start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) - - var ewma EWMA - ewma.Update(10, start) - ewma.Update(0, start.Add(time.Second)) - - if ewma.Get() != 5 { - t.Errorf("expected 5, got %f", ewma.Get()) - } - - ewma.Update(10, start.Add(2*time.Second)) - if ewma.Get() != 7.5 { - t.Errorf("expected 7.5, got %f", ewma.Get()) - } - - ewma.Update(10, start.Add(3*time.Second)) - if ewma.Get() != 8.75 { - t.Errorf("expected 8.75, got %f", ewma.Get()) - } -} - -func TestReset(t *testing.T) { - start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) - - ewma := NewEWMA(30.0) - ewma.Update(10, start) - ewma.Update(0, start.Add(30*time.Second)) - - if ewma.Get() != 5 { - t.Errorf("expected 5, got %f", ewma.Get()) - } - - ewma.Reset() - - if ewma.Get() != 0 { - t.Errorf("expected 0, got %f", ewma.Get()) - } - - ewma.Update(10, start.Add(90*time.Second)) - if ewma.Get() != 10 { - t.Errorf("expected 10, got %f", ewma.Get()) - } -} diff --git a/metrics/multilabelmap.go b/metrics/multilabelmap.go index 54d41bbae9ef2..fa31819d9c3f8 100644 --- a/metrics/multilabelmap.go +++ b/metrics/multilabelmap.go @@ -63,16 +63,16 @@ func LabelString(k any) string { var sb strings.Builder sb.WriteString("{") - for i := range t.NumField() { - if i > 0 { + first := true + for ft, fv := range rv.Fields() { + if !first { sb.WriteString(",") } - ft := t.Field(i) + first = false label := ft.Tag.Get("prom") if label == "" { label = strings.ToLower(ft.Name) } - fv := rv.Field(i) switch fv.Kind() { case reflect.String: fmt.Fprintf(&sb, "%s=%q", label, fv.String()) diff --git a/metrics/multilabelmap_test.go b/metrics/multilabelmap_test.go index 70554c63e50a0..0fa730992bc16 100644 --- a/metrics/multilabelmap_test.go +++ b/metrics/multilabelmap_test.go @@ -86,10 +86,10 @@ metricname{foo="si",bar="si"} 5 func TestMultiLabelMapTypes(t *testing.T) { type LabelTypes struct { - S string - B bool - I int - U uint + S string + B bool + Int int + U uint } m := new(MultiLabelMap[LabelTypes]) @@ -100,7 +100,7 @@ func TestMultiLabelMapTypes(t *testing.T) { m.WritePrometheus(&buf, "metricname") const want = `# TYPE metricname counter # HELP metricname some good stuff -metricname{s="a",b="true",i="-1",u="2"} 3 +metricname{s="a",b="true",int="-1",u="2"} 3 ` if got := buf.String(); got != want { t.Errorf("got %q; want %q", got, want) diff --git a/misc/genreadme/genreadme.go b/misc/genreadme/genreadme.go new file mode 100644 index 0000000000000..97a8d9e1640cc --- /dev/null +++ b/misc/genreadme/genreadme.go @@ -0,0 +1,267 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// The genreadme tool generates/updates README.md files in the tailscale repo. +// +// # Running +// +// From the repo root, run: `./tool/go run ./misc/genreadme` and it will update all +// the README.md files that are stale in the tree. +package main + +import ( + "bytes" + "errors" + "flag" + "fmt" + "go/parser" + "go/token" + "io" + "io/fs" + "log" + "os" + "path" + "path/filepath" + "runtime" + "strings" + + "github.com/creachadair/taskgroup" + "tailscale.com/tempfork/pkgdoc" +) + +// modulePath is the current module's import path, read from go.mod at startup. +var modulePath string + +var skip = map[string]bool{ + "out": true, +} + +// bkSkip lists directories where the generated file should not mention +// Buildkite because a deploy workflow is not set up for them. +var bkSkip = map[string]bool{} + +// defaultRoots are the directory trees walked when genreadme is run with +// no arguments. Add a directory here to opt its package (and any +// sub-packages) into README.md generation from godoc. +var defaultRoots = []string{ + "tsnet", +} + +func main() { + flag.Parse() + modulePath = readModulePath("go.mod") + var roots []string + switch flag.NArg() { + case 0: + roots = defaultRoots + case 1: + root := flag.Arg(0) + root = strings.TrimPrefix(root, "./") + root = strings.TrimSuffix(root, "/") + roots = []string{root} + default: + log.Fatalf("Usage: genreadme [dir]") + } + + var updateErrs []error + g, run := taskgroup.New(func(err error) { + updateErrs = append(updateErrs, err) + }).Limit(runtime.NumCPU() * 2) // usually I/O bound + + for _, root := range roots { + g.Go(func() error { + return fs.WalkDir(os.DirFS("."), root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if !d.IsDir() { + return nil + } + if skip[path] { + return fs.SkipDir + } + base := filepath.Base(path) + if base == "testdata" || (path != "." && base[0] == '.') { + return fs.SkipDir + } + run(func() error { + return update(path) + }) + return nil + }) + }) + } + g.Wait() + if err := errors.Join(updateErrs...); err != nil { + log.Fatal(err) + } +} + +func update(dir string) error { + readmePath := filepath.Join(dir, "README.md") + cur, err := os.ReadFile(readmePath) + exists := false + if err != nil && !os.IsNotExist(err) { + return err + } + if err == nil { + exists = true + if !isGenerated(cur) { + // Do nothing; a human wrote this file. + return nil + } + } + + newContents, err := getNewContent(dir) + if err != nil { + return err + } + if newContents == nil { + if exists { + log.Printf("Deleting %s ...", readmePath) + os.Remove(readmePath) + } + return nil + } + + if bytes.Equal(cur, newContents) { + return nil + } + log.Printf("Writing %s ...", readmePath) + return os.WriteFile(readmePath, newContents, 0644) +} + +func getNewContent(dir string) (newContent []byte, err error) { + dents, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + + generators := []struct { + name string + quickTest func(dir string, dents []fs.DirEntry) bool + generate func(dir string) ([]byte, error) + }{ + {"go", hasGoFiles, genGoDoc}, + } + for _, gen := range generators { + if !gen.quickTest(dir, dents) { + continue + } + newContent, err := gen.generate(dir) + if newContent == nil && err == nil { + // Generator declined to generate, try next + continue + } + return newContent, err + } + return nil, nil +} + +func genGoDoc(dir string) ([]byte, error) { + abs, err := filepath.Abs(dir) + if err != nil { + return nil, fmt.Errorf("failed to get absolute path for %q: %w", dir, err) + } + var importPath string + if modulePath != "" { + importPath = path.Join(modulePath, filepath.ToSlash(dir)) + } + godoc, err := pkgdoc.PackageDoc(abs, importPath) + if err != nil { + return nil, fmt.Errorf("failed to get package doc for %q: %w", dir, err) + } + if len(bytes.TrimSpace(godoc)) == 0 { + // No godoc; skipping. + return nil, nil + } + isLibrary := bytes.HasPrefix(godoc, []byte("package ")) + if isLibrary { + // Strip the "package X // import Y\n\n" clause emitted for library packages. + if i := bytes.Index(godoc, []byte("\n\n")); i != -1 { + godoc = godoc[i+2:] + } + } + if len(bytes.TrimSpace(godoc)) == 0 { + return nil, nil + } + var buf bytes.Buffer + io.WriteString(&buf, genHeader) + fmt.Fprintf(&buf, "\n# %s\n\n", filepath.Base(dir)) + if isLibrary && importPath != "" { + fmt.Fprintf(&buf, "[![Go Reference](https://pkg.go.dev/badge/%s.svg)](https://pkg.go.dev/%s)\n\n", importPath, importPath) + } + buf.Write(godoc) + + if !bytes.Contains(godoc, []byte("## Deploying")) { + deployPath := filepath.Join(dir, "deploy.sh") + if _, err := os.Stat(deployPath); err == nil { + fmt.Fprint(&buf, "\n## Deploying\n\n") + if hasBuildkite(dir) { + fmt.Fprintf(&buf, + "To deploy, run the https://buildkite.com/tailscale/deploy-%s workflow in Buildkite.\n", + filepath.Base(dir), + ) + } + fmt.Fprintf(&buf, "To deploy manually, run `./%s` from the repo root.\n\n", deployPath) + } + } + return buf.Bytes(), nil +} + +const genHeader = "\n" + +func isGenerated(b []byte) bool { return bytes.HasPrefix(b, []byte(genHeader)) } + +// readModulePath returns the module path declared in the given go.mod file, +// or "" if it can't be read or parsed. +func readModulePath(file string) string { + b, err := os.ReadFile(file) + if err != nil { + return "" + } + for line := range strings.Lines(string(b)) { + if rest, ok := strings.CutPrefix(strings.TrimSpace(line), "module "); ok { + return strings.Trim(strings.TrimSpace(rest), `"`) + } + } + return "" +} + +func hasBuildkite(dir string) bool { + if bkSkip[dir] { + return false + } + _, flyErr := os.Stat(filepath.Join(dir, "fly.toml")) + return flyErr != nil +} + +func hasGoFiles(dir string, dents []fs.DirEntry) bool { + var fset *token.FileSet + + for _, de := range dents { + name := de.Name() + if !strings.HasSuffix(name, ".go") || + strings.HasSuffix(name, "_test.go") { + continue + } + if fset == nil { + fset = token.NewFileSet() + } + + path := filepath.Join(dir, name) + f, err := os.Open(path) + if err != nil { + continue + } + pkgFile, err := parser.ParseFile(fset, "", f, parser.PackageClauseOnly) + f.Close() + if err != nil { + // skip files with parse errors + continue + } + + return pkgFile.Name.Name != "" + } + return false +} diff --git a/misc/git_hook/README.md b/misc/git_hook/README.md new file mode 100644 index 0000000000000..49d76893792ef --- /dev/null +++ b/misc/git_hook/README.md @@ -0,0 +1,35 @@ +# git_hook + +Tailscale's git hooks. + +The shared logic lives in the `githook/` package and is also imported by +`tailscale/corp`. + +## Install + +From the repo root: + + ./tool/go run ./misc/install-git-hooks.go + +The script auto-updates in the future. + + +## Adding your own hooks + +Create an executable `.git/hooks/.local` to chain a custom +script after a built-in hook. For example, put a custom check in +`.git/hooks/pre-commit.local` and `chmod +x` it. The local hook runs +only if the built-in hook succeeds; failure aborts the git operation. + + +## Version bumps + +The launcher rebuilds when the installed binary's version differs from +the concatenation of two files: + +* `githook/HOOK_VERSION` (shared): bump when changing anything under + `githook/` or `git-hook.go`. Downstream repos pick it up after + bumping their `tailscale.com` dependency. +* `misc/git_hook/HOOK_VERSION` (repo-local, optional): bump to force a + rebuild for repo-specific config changes without touching the shared + version. This repo does not use one. diff --git a/misc/git_hook/git-hook.go b/misc/git_hook/git-hook.go new file mode 100644 index 0000000000000..2cf3ff421ccdf --- /dev/null +++ b/misc/git_hook/git-hook.go @@ -0,0 +1,62 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// The git-hook command is Tailscale's git hook binary, built and +// installed under .git/hooks/ts-git-hook-bin by the launcher at +// .git/hooks/ts-git-hook. misc/install-git-hooks.go writes the initial +// launcher; subsequent HOOK_VERSION bumps trigger self-rebuilds. +// +// # Adding your own hooks +// +// To add your own hook alongside one we already hook, create an executable +// file .git/hooks/.local (e.g. pre-commit.local). It runs after +// the built-in hook succeeds. +package main + +import ( + "fmt" + "log" + "os" + "strings" + + "tailscale.com/misc/git_hook/githook" +) + +var pushRemotes = []string{ + "git@github.com:tailscale/tailscale", + "git@github.com:tailscale/tailscale.git", + "https://github.com/tailscale/tailscale", + "https://github.com/tailscale/tailscale.git", +} + +// hooks are the hook names this binary handles. Used by install to +// write per-hook wrappers; must stay in sync with the dispatcher below. +var hooks = []string{"pre-commit", "commit-msg", "pre-push"} + +func main() { + log.SetFlags(0) + if len(os.Args) < 2 { + return + } + cmd, args := os.Args[1], os.Args[2:] + + var err error + switch cmd { + case "version": + fmt.Print(strings.TrimSpace(githook.HookVersion) + ":0") + case "install": + err = githook.WriteHooks(hooks) + case "pre-commit": + err = githook.CheckForbiddenMarkers() + case "commit-msg": + err = githook.AddChangeID(args) + case "pre-push": + err = githook.CheckGoModReplaces(args, pushRemotes, nil) + } + if err != nil { + log.Fatalf("git-hook: %v: %v", cmd, err) + } + if err := githook.RunLocalHook(cmd, args); err != nil { + log.Fatalf("git-hook: %v", err) + } +} diff --git a/misc/git_hook/githook/HOOK_VERSION b/misc/git_hook/githook/HOOK_VERSION new file mode 100644 index 0000000000000..00750edc07d64 --- /dev/null +++ b/misc/git_hook/githook/HOOK_VERSION @@ -0,0 +1 @@ +3 diff --git a/misc/git_hook/githook/commit-msg.go b/misc/git_hook/githook/commit-msg.go new file mode 100644 index 0000000000000..e75bc79f39462 --- /dev/null +++ b/misc/git_hook/githook/commit-msg.go @@ -0,0 +1,64 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package githook + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "io" + "os" + "os/exec" +) + +// AddChangeID strips comments from the commit message at args[0] and +// prepends a random Change-Id trailer. +// +// Intended as a commit-msg hook. +// https://git-scm.com/docs/githooks#_commit_msg +func AddChangeID(args []string) error { + if len(args) != 1 { + return errors.New("usage: commit-msg message.txt") + } + file := args[0] + msg, err := os.ReadFile(file) + if err != nil { + return err + } + msg = filterCutLine(msg) + + var id [20]byte + if _, err := io.ReadFull(rand.Reader, id[:]); err != nil { + return fmt.Errorf("could not generate Change-Id: %v", err) + } + cmdLines := [][]string{ + {"git", "stripspace", "--strip-comments"}, + {"git", "interpret-trailers", "--no-divider", "--where=start", "--if-exists", "doNothing", "--trailer", fmt.Sprintf("Change-Id: I%x", id)}, + } + for _, cmdLine := range cmdLines { + if len(msg) == 0 { + // Don't let commands turn an empty message into a non-empty one (issue 2205). + break + } + cmd := exec.Command(cmdLine[0], cmdLine[1:]...) + cmd.Stdin = bytes.NewReader(msg) + msg, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to run %v: %w\n%s", cmd, err, msg) + } + } + return os.WriteFile(file, msg, 0666) +} + +var gitCutLine = []byte("# ------------------------ >8 ------------------------") + +// filterCutLine strips a `git commit -v`-style cutline and everything +// after it from msg. +func filterCutLine(msg []byte) []byte { + if before, _, ok := bytes.Cut(msg, gitCutLine); ok { + return before + } + return msg +} diff --git a/misc/git_hook/githook/githook.go b/misc/git_hook/githook/githook.go new file mode 100644 index 0000000000000..aa44051fc2675 --- /dev/null +++ b/misc/git_hook/githook/githook.go @@ -0,0 +1,52 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package githook contains the shared implementation of Tailscale's git +// hooks. The tailscale/tailscale and tailscale/corp repositories each have +// a thin main package that dispatches to this one, calling individual +// hook functions with per-repo arguments as needed. +package githook + +import ( + _ "embed" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" +) + +// Launcher is the canonical bytes of launcher.sh. Downstream repos +// (e.g. tailscale/corp) rely on these bytes at install time. +// +//go:embed launcher.sh +var Launcher []byte + +// HookVersion is the shared version of this package and launcher.sh. +// Bump HOOK_VERSION on any change under this package. +// +//go:embed HOOK_VERSION +var HookVersion string + +// RunLocalHook runs an optional user-supplied hook at +// .git/hooks/.local, if present. +func RunLocalHook(hookName string, args []string) error { + cmdPath, err := os.Executable() + if err != nil { + return err + } + localHookPath := filepath.Join(filepath.Dir(cmdPath), hookName+".local") + if _, err := os.Stat(localHookPath); errors.Is(err, os.ErrNotExist) { + return nil + } else if err != nil { + return fmt.Errorf("checking for local hook: %w", err) + } + + cmd := exec.Command(localHookPath, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("running local hook %q: %w", localHookPath, err) + } + return nil +} diff --git a/misc/git_hook/githook/install.go b/misc/git_hook/githook/install.go new file mode 100644 index 0000000000000..3c08daf8d7e6a --- /dev/null +++ b/misc/git_hook/githook/install.go @@ -0,0 +1,177 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package githook + +import ( + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" +) + +// Install writes the launcher to .git/hooks/ts-git-hook and runs it +// once with "version", bootstrapping the binary build and per-hook +// wrappers. Called from each repo's misc/install-git-hooks.go. +func Install() error { + hookDir, err := findHookDir() + if err != nil { + return err + } + target := filepath.Join(hookDir, "ts-git-hook") + if err := writeLauncher(target); err != nil { + return err + } + + // The launcher execs the binary with our arg at the end; we pass + // "version" only to trigger the rebuild-if-stale path, and discard + // its stdout so the version string doesn't leak to the caller. + cmd := exec.Command(target, "version") + cmd.Stdout = io.Discard + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("initial hook setup failed: %v", err) + } + return nil +} + +// WriteHooks writes the launcher to .git/hooks/ts-git-hook and a wrapper +// for each name in hooks to .git/hooks/. Stale wrappers from +// prior versions (ours, but no longer in hooks) are removed. If a path +// we are about to write exists and is not one of our wrappers, +// WriteHooks aborts with an error rather than clobber the user's hook. +// Called by the binary's "install" handler (after a rebuild) and by +// Install (initial setup). +func WriteHooks(hooks []string) error { + hookDir, err := findHookDir() + if err != nil { + return err + } + if err := writeLauncher(filepath.Join(hookDir, "ts-git-hook")); err != nil { + return err + } + want := make(map[string]bool, len(hooks)) + for _, h := range hooks { + want[h] = true + } + entries, err := os.ReadDir(hookDir) + if err != nil { + return fmt.Errorf("reading hooks dir: %v", err) + } + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + path := filepath.Join(hookDir, name) + mine, err := isOurWrapper(path) + if err != nil { + return fmt.Errorf("inspecting %s: %v", path, err) + } + switch { + case want[name] && !mine: + return fmt.Errorf("%s exists and is not a ts-git-hook wrapper; "+ + "move your hook to %s.local (it will be chained after the wrapper) or delete it, then re-run: ./tool/go run ./misc/install-git-hooks.go", + path, name) + case !want[name] && mine: + // Stale wrapper from a prior version (e.g. a hook we used + // to install but no longer do). + if err := os.Remove(path); err != nil { + return fmt.Errorf("removing stale wrapper %s: %v", name, err) + } + } + } + for _, h := range hooks { + content := fmt.Sprintf(wrapperScript, h) + if err := os.WriteFile(filepath.Join(hookDir, h), []byte(content), 0755); err != nil { + return fmt.Errorf("writing wrapper for %s: %v", h, err) + } + } + return nil +} + +// isOurWrapper reports whether path is a hook wrapper written by us +// (in any historical format). Files we will never own (the launcher +// itself, user-chained .local hooks, git's .sample examples) return +// false unconditionally and are not read. An I/O error other than +// "not found" is returned to the caller; a missing file is not an +// error. +func isOurWrapper(path string) (bool, error) { + name := filepath.Base(path) + if name == "ts-git-hook" || + strings.HasSuffix(name, ".local") || + strings.HasSuffix(name, ".sample") { + return false, nil + } + b, err := os.ReadFile(path) + if os.IsNotExist(err) { + return false, nil + } + if err != nil { + return false, err + } + return wrapperRE.Match(b), nil +} + +// writeLauncher writes the embedded launcher to target via atomic rename, +// so a currently-running launcher keeps reading its old inode. +func writeLauncher(target string) error { + dir, name := filepath.Split(target) + f, err := os.CreateTemp(dir, name+".*") + if err != nil { + return fmt.Errorf("creating temp launcher: %v", err) + } + tmp := f.Name() + if _, err := f.Write(Launcher); err != nil { + f.Close() + os.Remove(tmp) + return fmt.Errorf("writing temp launcher: %v", err) + } + if err := f.Close(); err != nil { + os.Remove(tmp) + return err + } + if err := os.Chmod(tmp, 0755); err != nil { + os.Remove(tmp) + return err + } + if err := os.Rename(tmp, target); err != nil { + os.Remove(tmp) + return fmt.Errorf("installing launcher: %v", err) + } + return nil +} + +func findHookDir() (string, error) { + out, err := exec.Command("git", "rev-parse", "--git-path", "hooks").CombinedOutput() + if err != nil { + return "", fmt.Errorf("finding hooks dir: %v, %s", err, out) + } + hookDir, err := filepath.Abs(strings.TrimSpace(string(out))) + if err != nil { + return "", err + } + fi, err := os.Stat(hookDir) + if err != nil { + return "", fmt.Errorf("checking hooks dir: %v", err) + } + if !fi.IsDir() { + return "", fmt.Errorf("%s is not a directory", hookDir) + } + return hookDir, nil +} + +const wrapperScript = `#!/usr/bin/env bash +exec "$(dirname "${BASH_SOURCE[0]}")/ts-git-hook" %s "$@" +` + +// wrapperRE matches every historical shape of wrapperScript: a tiny +// bash script that execs a sibling ts-git-hook with a single hook-name +// argument. The inner quoting of ${BASH_SOURCE[0]} changed between +// versions, hence the "?s. +var wrapperRE = regexp.MustCompile( + `\A#!/usr/bin/env bash\nexec "\$\(dirname "?\$\{BASH_SOURCE\[0\]\}"?\)/ts-git-hook" [\w-]+ "\$@"\n?\z`, +) diff --git a/misc/git_hook/githook/launcher.sh b/misc/git_hook/githook/launcher.sh new file mode 100755 index 0000000000000..eddab585e2dbb --- /dev/null +++ b/misc/git_hook/githook/launcher.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash +# ts-git-hook launcher (installed at .git/hooks/ts-git-hook). +# +# Written by misc/install-git-hooks.go from the canonical copy embedded +# in tailscale.com/misc/git_hook/githook. On every invocation it: +# +# 1. Compares the binary's reported version against the shared +# githook HOOK_VERSION (resolved via `go list -m tailscale.com`) +# plus the repo-local HOOK_VERSION. +# 2. If stale or missing: rebuilds ts-git-hook-bin and runs +# `ts-git-hook-bin install`. +# 3. Execs the binary with the hook's args. +set -euo pipefail + +REPO_ROOT="$(git rev-parse --show-toplevel 2>/dev/null)" || { + echo "git-hook: not in a git repo" >&2 + exit 1 +} + +HOOK_DIR="$(git -C "$REPO_ROOT" rev-parse --git-path hooks)" +case "$HOOK_DIR" in +/*) ;; +*) HOOK_DIR="$REPO_ROOT/$HOOK_DIR" ;; +esac + +# Windows (Git for Windows / MSYS2) needs .exe suffixes. +EXE="" +case "$(uname -s)" in MINGW* | MSYS* | CYGWIN*) EXE=".exe" ;; esac + +BINARY="$HOOK_DIR/ts-git-hook-bin$EXE" + +GO="$REPO_ROOT/tool/go$EXE" +if [ ! -x "$GO" ]; then GO=go; fi + +OSS_DIR="$(cd "$REPO_ROOT" && GOWORK=off "$GO" list -m -f '{{.Dir}}' tailscale.com 2>/dev/null || true)" +SHARED_VER="$(cat "$OSS_DIR/misc/git_hook/githook/HOOK_VERSION" 2>/dev/null || echo 0)" +LOCAL_VER="$(cat "$REPO_ROOT/misc/git_hook/HOOK_VERSION" 2>/dev/null || echo 0)" +WANT="$SHARED_VER:$LOCAL_VER" +HAVE="$("$BINARY" version 2>/dev/null || echo none)" + +if [ "$WANT" != "$HAVE" ]; then + echo "git-hook: rebuilding ts-git-hook-bin..." >&2 + (cd "$REPO_ROOT" && GOWORK=off "$GO" build -o "$BINARY" ./misc/git_hook) || { + echo "git-hook: rebuild failed, run: ./tool/go run ./misc/install-git-hooks.go" >&2 + exit 1 + } + "$BINARY" install +fi + +exec "$BINARY" "$@" + diff --git a/misc/git_hook/githook/pre-commit.go b/misc/git_hook/githook/pre-commit.go new file mode 100644 index 0000000000000..30e4f6a9e42c9 --- /dev/null +++ b/misc/git_hook/githook/pre-commit.go @@ -0,0 +1,62 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package githook + +import ( + "bytes" + "errors" + "fmt" + "os/exec" + + "github.com/fatih/color" + "github.com/sourcegraph/go-diff/diff" +) + +var preCommitForbiddenPatterns = [][]byte{ + // Concatenation avoids tripping the check on this file. + []byte("NOCOM" + "MIT"), + []byte("DO NOT " + "SUBMIT"), +} + +// CheckForbiddenMarkers scans the staged diff for forbidden markers +// and returns an error if any are found. +// +// Intended as a pre-commit hook. +// https://git-scm.com/docs/githooks#_pre_commit +func CheckForbiddenMarkers() error { + diffOut, err := exec.Command("git", "diff", "--cached").Output() + if err != nil { + return fmt.Errorf("could not get git diff: %w", err) + } + + diffs, err := diff.ParseMultiFileDiff(diffOut) + if err != nil { + return fmt.Errorf("could not parse diff: %w", err) + } + + foundForbidden := false + for _, d := range diffs { + for _, hunk := range d.Hunks { + lines := bytes.Split(hunk.Body, []byte{'\n'}) + for i, line := range lines { + if len(line) == 0 || line[0] != '+' { + continue + } + for _, forbidden := range preCommitForbiddenPatterns { + if bytes.Contains(line, forbidden) { + if !foundForbidden { + color.New(color.Bold, color.FgRed, color.Underline).Printf("%s found:\n", forbidden) + } + fmt.Printf("%s:%d: %s\n", d.NewName[2:], int(hunk.NewStartLine)+i, line[1:]) + foundForbidden = true + } + } + } + } + } + if foundForbidden { + return errors.New("found forbidden string") + } + return nil +} diff --git a/misc/git_hook/githook/pre-push.go b/misc/git_hook/githook/pre-push.go new file mode 100644 index 0000000000000..9d5624523fe9d --- /dev/null +++ b/misc/git_hook/githook/pre-push.go @@ -0,0 +1,112 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package githook + +import ( + "bufio" + "fmt" + "os" + "os/exec" + "strings" + + "golang.org/x/mod/modfile" +) + +// CheckGoModReplaces reads pushes from stdin and, for pushes to a +// remote URL in watchedRemotes, rejects any commit whose go.mod has a +// directory-path replace that is not in allowedReplaceDirs. args is +// the pre-push hook's argv (remoteName, remoteLoc). +// +// Intended as a pre-push hook. +// https://git-scm.com/docs/githooks#_pre_push +func CheckGoModReplaces(args []string, watchedRemotes, allowedReplaceDirs []string) error { + if len(args) < 2 { + return fmt.Errorf("pre-push: expected 2 args, got %d", len(args)) + } + remoteLoc := args[1] + + watched := false + for _, r := range watchedRemotes { + if r == remoteLoc { + watched = true + break + } + } + if !watched { + return nil + } + + pushes, err := readPushes() + if err != nil { + return fmt.Errorf("reading pushes: %w", err) + } + for _, p := range pushes { + if p.isDoNotMergeRef() { + continue + } + if err := checkCommit(p.localSHA, allowedReplaceDirs); err != nil { + return fmt.Errorf("not allowing push of %v to %v: %v", p.localSHA, p.remoteRef, err) + } + } + return nil +} + +func checkCommit(sha string, allowedReplaceDirs []string) error { + if sha == zeroRef { + // Allow ref deletions. + return nil + } + goMod, err := exec.Command("git", "show", sha+":go.mod").Output() + if err != nil { + return err + } + mf, err := modfile.Parse("go.mod", goMod, nil) + if err != nil { + return fmt.Errorf("failed to parse its go.mod: %v", err) + } + for _, r := range mf.Replace { + if !modfile.IsDirectoryPath(r.New.Path) { + continue + } + allowed := false + for _, a := range allowedReplaceDirs { + if a == r.New.Path { + allowed = true + break + } + } + if !allowed { + return fmt.Errorf("go.mod contains replace from %v => %v", r.Old.Path, r.New.Path) + } + } + return nil +} + +const zeroRef = "0000000000000000000000000000000000000000" + +type push struct { + localRef string + localSHA string + remoteRef string + remoteSHA string +} + +func (p *push) isDoNotMergeRef() bool { + return strings.HasSuffix(p.remoteRef, "/DO-NOT-MERGE") +} + +func readPushes() (pushes []push, err error) { + bs := bufio.NewScanner(os.Stdin) + for bs.Scan() { + f := strings.Fields(bs.Text()) + if len(f) != 4 { + return nil, fmt.Errorf("unexpected push line %q", bs.Text()) + } + pushes = append(pushes, push{f[0], f[1], f[2], f[3]}) + } + if err := bs.Err(); err != nil { + return nil, err + } + return pushes, nil +} diff --git a/misc/install-git-hooks.go b/misc/install-git-hooks.go new file mode 100644 index 0000000000000..813a456016788 --- /dev/null +++ b/misc/install-git-hooks.go @@ -0,0 +1,21 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// The install-git-hooks program installs git hooks by delegating to +// githook.Install. See that function's doc for what it does. +package main + +import ( + "log" + + "tailscale.com/misc/git_hook/githook" +) + +func main() { + log.SetFlags(0) + if err := githook.Install(); err != nil { + log.Fatalf("install-git-hooks: %v", err) + } +} diff --git a/net/art/stride_table_test.go b/net/art/stride_table_test.go index e797f40ee0ddc..8279a545e132d 100644 --- a/net/art/stride_table_test.go +++ b/net/art/stride_table_test.go @@ -19,7 +19,7 @@ import ( func TestInversePrefix(t *testing.T) { t.Parallel() for i := range 256 { - for len := 0; len < 9; len++ { + for len := range 9 { addr := i & (0xFF << (8 - len)) idx := prefixIndex(uint8(addr), len) addr2, len2 := inversePrefixIndex(idx) diff --git a/net/batching/conn.go b/net/batching/conn.go index 1631c33cfe448..1843a2cfced5a 100644 --- a/net/batching/conn.go +++ b/net/batching/conn.go @@ -19,14 +19,24 @@ var ( _ ipv6.Message = ipv4.Message{} ) -// Conn is a nettype.PacketConn that provides batched i/o using +// Conn is a [nettype.PacketConn] that provides batched i/o using // platform-specific optimizations, e.g. {recv,send}mmsg & UDP GSO/GRO. // +// Conn does not support single packet reads (see ReadFromUDPAddrPort docs). It +// is the caller's responsibility to use the appropriate read API where a +// [nettype.PacketConn] has been upgraded to support batched i/o. +// // Conn originated from (and is still used by) magicsock where its API was // strongly influenced by [wireguard-go/conn.Bind] constraints, namely // wireguard-go's ownership of packet memory. type Conn interface { nettype.PacketConn + // ReadFromUDPAddrPort always returns an error, as UDP GRO is incompatible + // with single packet reads. A single datagram may be multiple, coalesced + // datagrams, and this API lacks the ability to pass that context. + // + // TODO: consider detaching Conn from [nettype.PacketConn] + ReadFromUDPAddrPort([]byte) (int, netip.AddrPort, error) // ReadBatch reads messages from [Conn] into msgs. It returns the number of // messages the caller should evaluate for nonzero len, as a zero len // message may fall on either side of a nonzero. diff --git a/net/batching/conn_default.go b/net/batching/conn_default.go index 0d208578b6d06..77c4c8b6a8adc 100644 --- a/net/batching/conn_default.go +++ b/net/batching/conn_default.go @@ -10,7 +10,7 @@ import ( ) // TryUpgradeToConn is no-op on all platforms except linux. -func TryUpgradeToConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn { +func TryUpgradeToConn(pconn nettype.PacketConn, _ string, _ int, _ string) nettype.PacketConn { return pconn } diff --git a/net/batching/conn_linux.go b/net/batching/conn_linux.go index 373625b772738..ea11f439a88bd 100644 --- a/net/batching/conn_linux.go +++ b/net/batching/conn_linux.go @@ -24,6 +24,7 @@ import ( "tailscale.com/net/neterror" "tailscale.com/net/packet" "tailscale.com/types/nettype" + "tailscale.com/util/clientmetric" ) // xnetBatchReaderWriter defines the batching i/o methods of @@ -51,26 +52,23 @@ var ( // linuxBatchingConn is a UDP socket that provides batched i/o. It implements // [Conn]. type linuxBatchingConn struct { - pc *net.UDPConn - xpc xnetBatchReaderWriter - rxOffload bool // supports UDP GRO or similar - txOffload atomic.Bool // supports UDP GSO or similar - setGSOSizeInControl func(control *[]byte, gsoSize uint16) // typically setGSOSizeInControl(); swappable for testing - getGSOSizeFromControl func(control []byte) (int, error) // typically getGSOSizeFromControl(); swappable for testing - sendBatchPool sync.Pool + pc *net.UDPConn + xpc xnetBatchReaderWriter + rxOffload bool // supports UDP GRO or similar + txOffload atomic.Bool // supports UDP GSO or similar + sendBatchPool sync.Pool + rxqOverflowsMetric *clientmetric.Metric + + // readOpMu guards read operations that must perform accounting against + // rxqOverflows in single-threaded fashion. There are no concurrent usages + // of read operations at the time of writing (2026-03-09), but it would be + // unidiomatic to push this responsibility onto callers. + readOpMu sync.Mutex + rxqOverflows uint32 // kernel pumps a cumulative counter, which we track to push a clientmetric delta value } func (c *linuxBatchingConn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { - if c.rxOffload { - // UDP_GRO is opt-in on Linux via setsockopt(). Once enabled you may - // receive a "monster datagram" from any read call. The ReadFrom() API - // does not support passing the GSO size and is unsafe to use in such a - // case. Other platforms may vary in behavior, but we go with the most - // conservative approach to prevent this from becoming a footgun in the - // future. - return 0, netip.AddrPort{}, errors.New("rx UDP offload is enabled on this socket, single packet reads are unavailable") - } - return c.pc.ReadFromUDPAddrPort(p) + return 0, netip.AddrPort{}, errors.New("single packet reads are unsupported") } func (c *linuxBatchingConn) SetDeadline(t time.Time) error { @@ -89,6 +87,15 @@ const ( // This was initially established for Linux, but may split out to // GOOS-specific values later. It originates as UDP_MAX_SEGMENTS in the // kernel's TX path, and UDP_GRO_CNT_MAX for RX. + // + // As long as we use one fragment per datagram, this also serves as a + // limit for the number of fragments we can coalesce during scatter-gather writes. + // + // 64 is below the 1024 of IOV_MAX (Linux) or UIO_MAXIOV (BSD), + // and the 256 of WSABUF_MAX_COUNT (Windows). + // + // (2026-04) If we begin shipping datagrams in more than one fragment, + // an independent fragment count limit needs to be implemented. udpSegmentMaxDatagrams = 64 ) @@ -101,15 +108,24 @@ const ( // coalesceMessages iterates 'buffs', setting and coalescing them in 'msgs' // where possible while maintaining datagram order. // +// It aggregates message components as a list of buffers without copying, +// and expects to be used only on Linux with scatter-gather writes via sendmmsg(2). +// +// All msgs[i].Buffers len must be one. Will panic if there is not enough msgs +// to coalesce all buffs. +// // All msgs have their Addr field set to addr. // // All msgs[i].Buffers[0] are preceded by a Geneve header (geneve) if geneve.VNI.IsSet(). +// +// TODO(illotum) explore MSG_ZEROCOPY for large writes (>10KB). func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, geneve packet.GeneveHeader, buffs [][]byte, msgs []ipv6.Message, offset int) int { var ( - base = -1 // index of msg we are currently coalescing into - gsoSize int // segmentation size of msgs[base] - dgramCnt int // number of dgrams coalesced into msgs[base] - endBatch bool // tracking flag to start a new batch on next iteration of buffs + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of buffs + coalescedLen int // bytes coalesced into msgs[base] ) maxPayloadLen := maxIPv4PayloadLen if addr.IP.To4() == nil { @@ -124,19 +140,18 @@ func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, geneve packet.Ge } if i > 0 { msgLen := len(buff) - baseLenBefore := len(msgs[base].Buffers[0]) - freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore - if msgLen+baseLenBefore <= maxPayloadLen && + if msgLen+coalescedLen <= maxPayloadLen && msgLen <= gsoSize && - msgLen <= freeBaseCap && dgramCnt < udpSegmentMaxDatagrams && !endBatch { - msgs[base].Buffers[0] = append(msgs[base].Buffers[0], make([]byte, msgLen)...) - copy(msgs[base].Buffers[0][baseLenBefore:], buff) + // msgs[base].Buffers[0] is set to buff[i] when a new base is set. + // This appends a struct iovec element in the underlying struct msghdr (scatter-gather). + msgs[base].Buffers = append(msgs[base].Buffers, buff) if i == len(buffs)-1 { - c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) + setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) } dgramCnt++ + coalescedLen += msgLen if msgLen < gsoSize { // A smaller than gsoSize packet on the tail is legal, but // it must end the batch. @@ -146,7 +161,7 @@ func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, geneve packet.Ge } } if dgramCnt > 1 { - c.setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) + setGSOSizeInControl(&msgs[base].OOB, uint16(gsoSize)) } // Reset prior to incrementing base since we are preparing to start a // new potential batch. @@ -157,6 +172,7 @@ func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, geneve packet.Ge msgs[base].Buffers[0] = buff msgs[base].Addr = addr dgramCnt = 1 + coalescedLen = len(buff) } return base + 1 } @@ -173,7 +189,10 @@ func (c *linuxBatchingConn) getSendBatch() *sendBatch { func (c *linuxBatchingConn) putSendBatch(batch *sendBatch) { for i := range batch.msgs { - batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers, OOB: batch.msgs[i].OOB} + // Non coalesced write paths access only batch.msgs[i].Buffers[0], + // but we append more during [linuxBatchingConn.coalesceMessages]. + // Leave index zero accessible: + batch.msgs[i] = ipv6.Message{Buffers: batch.msgs[i].Buffers[:1], OOB: batch.msgs[i].OOB} } c.sendBatchPool.Put(batch) } @@ -262,7 +281,7 @@ func (c *linuxBatchingConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsg end = msg.N numToSplit = 1 ) - gsoSize, err = c.getGSOSizeFromControl(msg.OOB[:msg.NN]) + gsoSize, err = getGSOSizeFromControl(msg.OOB[:msg.NN]) if err != nil { return n, err } @@ -294,16 +313,87 @@ func (c *linuxBatchingConn) splitCoalescedMessages(msgs []ipv6.Message, firstMsg return n, nil } +// getDataFromControl returns the data portion of the first control msg with +// matching cmsgLevel, matching cmsgType, and min data len of minDataLen, in +// control. If no matching cmsg is found or the len(control) < unix.SizeofCmsghdr, +// this function returns nil data. A non-nil error will be returned if +// len(control) > unix.SizeofCmsghdr but its contents cannot be parsed as a +// socket control message. +func getDataFromControl(control []byte, cmsgLevel, cmsgType int32, minDataLen int) ([]byte, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return nil, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == cmsgLevel && hdr.Type == cmsgType && len(data) >= minDataLen { + return data, nil + } + } + return nil, nil +} + +// getRXQOverflowsFromControl returns the rxq overflows cumulative counter found +// in control. If no rxq counter is found or the len(control) < unix.SizeofCmsghdr, +// this function returns 0. A non-nil error will be returned if control is +// malformed. +func getRXQOverflowsFromControl(control []byte) (uint32, error) { + data, err := getDataFromControl(control, unix.SOL_SOCKET, unix.SO_RXQ_OVFL, 4) + if err != nil { + return 0, err + } + if len(data) >= 4 { + return binary.NativeEndian.Uint32(data), nil + } + return 0, nil +} + +// handleRXQOverflowCounter handles any rx queue overflow counter contained in +// the tail of msgs. +func (c *linuxBatchingConn) handleRXQOverflowCounter(msgs []ipv6.Message, n int, rxErr error) { + if n == 0 || rxErr != nil || c.rxqOverflowsMetric == nil { + return + } + tailMsg := msgs[n-1] // we only care about the latest value as it's a cumulative counter + if tailMsg.NN == 0 { + return + } + rxqOverflows, err := getRXQOverflowsFromControl(tailMsg.OOB[:tailMsg.NN]) + if err != nil { + return + } + // The counter is always present once nonzero on the kernel side. Compare it + // with our previous view, push the delta to the clientmetric, and update + // our view. + if rxqOverflows == c.rxqOverflows { + return + } + delta := int64(rxqOverflows - c.rxqOverflows) + c.rxqOverflowsMetric.Add(delta) + c.rxqOverflows = rxqOverflows +} + func (c *linuxBatchingConn) ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) { + c.readOpMu.Lock() + defer c.readOpMu.Unlock() if !c.rxOffload || len(msgs) < 2 { - return c.xpc.ReadBatch(msgs, flags) + n, err = c.xpc.ReadBatch(msgs, flags) + c.handleRXQOverflowCounter(msgs, n, err) + return n, err } // Read into the tail of msgs, split into the head. readAt := len(msgs) - 2 - numRead, err := c.xpc.ReadBatch(msgs[readAt:], 0) - if err != nil || numRead == 0 { + n, err = c.xpc.ReadBatch(msgs[readAt:], 0) + if err != nil || n == 0 { return 0, err } + c.handleRXQOverflowCounter(msgs[readAt:], n, err) return c.splitCoalescedMessages(msgs, readAt) } @@ -319,6 +409,21 @@ func (c *linuxBatchingConn) Close() error { return c.pc.Close() } +// tryEnableRXQOverflowsCounter attempts to enable the SO_RXQ_OVFL socket option +// on pconn, and returns the result. SO_RXQ_OVFL was added in Linux v2.6.33. +func tryEnableRXQOverflowsCounter(pconn nettype.PacketConn) (enabled bool) { + if c, ok := pconn.(*net.UDPConn); ok { + rc, err := c.SyscallConn() + if err != nil { + return + } + rc.Control(func(fd uintptr) { + enabled = syscall.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RXQ_OVFL, 1) == nil + }) + } + return enabled +} + // tryEnableUDPOffload attempts to enable the UDP_GRO socket option on pconn, // and returns two booleans indicating TX and RX UDP offload support. func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) { @@ -340,32 +445,25 @@ func tryEnableUDPOffload(pconn nettype.PacketConn) (hasTX bool, hasRX bool) { return hasTX, hasRX } -// getGSOSizeFromControl returns the GSO size found in control. If no GSO size -// is found or the len(control) < unix.SizeofCmsghdr, this function returns 0. -// A non-nil error will be returned if len(control) > unix.SizeofCmsghdr but -// its contents cannot be parsed as a socket control message. +// getGSOSizeFromControl returns the GSO size found in control associated with a +// cmsg type of UDP_GRO, which the kernel populates in the read direction. If no +// GSO size is found or the len(control) < unix.SizeofCmsghdr, this function +// returns 0. A non-nil error will be returned if control is malformed. func getGSOSizeFromControl(control []byte) (int, error) { - var ( - hdr unix.Cmsghdr - data []byte - rem = control - err error - ) - - for len(rem) > unix.SizeofCmsghdr { - hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) - if err != nil { - return 0, fmt.Errorf("error parsing socket control message: %w", err) - } - if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= 2 { - return int(binary.NativeEndian.Uint16(data[:2])), nil - } + data, err := getDataFromControl(control, unix.SOL_UDP, unix.UDP_GRO, 2) + if err != nil { + return 0, err + } + if len(data) >= 2 { + return int(binary.NativeEndian.Uint16(data)), nil } return 0, nil } // setGSOSizeInControl sets a socket control message in control containing -// gsoSize. If len(control) < controlMessageSize control's len will be set to 0. +// gsoSize with an associated cmsg type of UDP_SEGMENT, which we are responsible +// for populating prior to writing towards the kernel. If len(control) < controlMessageSize +// control's len will be set to 0. func setGSOSizeInControl(control *[]byte, gsoSize uint16) { *control = (*control)[:0] if cap(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) { @@ -383,10 +481,39 @@ func setGSOSizeInControl(control *[]byte, gsoSize uint16) { *control = (*control)[:unix.CmsgSpace(2)] } +var ( + rxqOverflowsMetricsMu sync.Mutex + rxqOverflowsMetricsByName map[string]*clientmetric.Metric +) + +// getRXQOverflowsMetric returns a counter-based [*clientmetric.Metric] for the +// provided name in a thread-safe manner. Callers may pass the same metric name +// multiple times, which is common across rebinds of the underlying, associated +// [Conn]. +func getRXQOverflowsMetric(name string) *clientmetric.Metric { + if len(name) == 0 { + return nil + } + rxqOverflowsMetricsMu.Lock() + defer rxqOverflowsMetricsMu.Unlock() + m, ok := rxqOverflowsMetricsByName[name] + if ok { + return m + } + if rxqOverflowsMetricsByName == nil { + rxqOverflowsMetricsByName = make(map[string]*clientmetric.Metric) + } + m = clientmetric.NewCounter(name) + rxqOverflowsMetricsByName[name] = m + return m +} + // TryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades // pconn to a [Conn] if appropriate. A batch size of [IdealBatchSize] is -// suggested for the best performance. -func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn { +// suggested for the best performance. If len(rxqOverflowsMetricName) is +// nonzero, then read ops will propagate the SO_RXQ_OVFL control message counter +// to a clientmetric with the supplied name. +func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int, rxqOverflowsMetricName string) nettype.PacketConn { if runtime.GOOS != "linux" { // Exclude Android. return pconn @@ -408,9 +535,7 @@ func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) n return pconn } b := &linuxBatchingConn{ - pc: uc, - getGSOSizeFromControl: getGSOSizeFromControl, - setGSOSizeInControl: setGSOSizeInControl, + pc: uc, sendBatchPool: sync.Pool{ New: func() any { ua := &net.UDPAddr{ @@ -440,15 +565,23 @@ func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) n var txOffload bool txOffload, b.rxOffload = tryEnableUDPOffload(uc) b.txOffload.Store(txOffload) + if len(rxqOverflowsMetricName) > 0 && tryEnableRXQOverflowsCounter(uc) { + // Don't register the metric unless the socket option has been + // successfully set, otherwise we will report a misleading zero value + // counter on the wire. This is one reason why we prefer to handle + // clientmetric instantiation internally, vs letting callers pass them + // to TryUpgradeToConn. + b.rxqOverflowsMetric = getRXQOverflowsMetric(rxqOverflowsMetricName) + } return b } var controlMessageSize = -1 // bomb if used for allocation before init func init() { - // controlMessageSize is set to hold a UDP_GRO or UDP_SEGMENT control - // message. These contain a single uint16 of data. - controlMessageSize = unix.CmsgSpace(2) + controlMessageSize = + unix.CmsgSpace(2) + // UDP_GRO or UDP_SEGMENT gsoSize (uint16) + unix.CmsgSpace(4) // SO_RXQ_OVFL counter (uint32) } // MinControlMessageSize returns the minimum control message size required to diff --git a/net/batching/conn_linux_test.go b/net/batching/conn_linux_test.go index a15de4f671ec6..fa4eef33c5820 100644 --- a/net/batching/conn_linux_test.go +++ b/net/batching/conn_linux_test.go @@ -5,43 +5,30 @@ package batching import ( "encoding/binary" + "io" + "math" "net" "testing" "unsafe" + qt "github.com/frankban/quicktest" "github.com/tailscale/wireguard-go/conn" "golang.org/x/net/ipv6" "golang.org/x/sys/unix" "tailscale.com/net/packet" ) -func setGSOSize(control *[]byte, gsoSize uint16) { - *control = (*control)[:cap(*control)] - binary.LittleEndian.PutUint16(*control, gsoSize) -} - -func getGSOSize(control []byte) (int, error) { - if len(control) < 2 { - return 0, nil - } - return int(binary.LittleEndian.Uint16(control)), nil -} - func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) { - c := &linuxBatchingConn{ - setGSOSizeInControl: setGSOSize, - getGSOSizeFromControl: getGSOSize, - } + c := &linuxBatchingConn{} - newMsg := func(n, gso int) ipv6.Message { + newMsg := func(n int, gso uint16) ipv6.Message { msg := ipv6.Message{ Buffers: [][]byte{make([]byte, 1024)}, N: n, - OOB: make([]byte, 2), + OOB: gsoControl(gso), } - binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) if gso > 0 { - msg.NN = 2 + msg.NN = len(msg.OOB) } return msg } @@ -55,7 +42,7 @@ func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) { wantErr bool }{ { - name: "second last split last empty", + name: "second-last-split-last-empty", msgs: []ipv6.Message{ newMsg(0, 0), newMsg(0, 0), @@ -68,7 +55,7 @@ func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) { wantErr: false, }, { - name: "second last no split last empty", + name: "second-last-no-split-last-empty", msgs: []ipv6.Message{ newMsg(0, 0), newMsg(0, 0), @@ -81,7 +68,7 @@ func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) { wantErr: false, }, { - name: "second last no split last no split", + name: "second-last-no-split-last-no-split", msgs: []ipv6.Message{ newMsg(0, 0), newMsg(0, 0), @@ -94,7 +81,7 @@ func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) { wantErr: false, }, { - name: "second last no split last split", + name: "second-last-no-split-last-split", msgs: []ipv6.Message{ newMsg(0, 0), newMsg(0, 0), @@ -107,7 +94,7 @@ func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) { wantErr: false, }, { - name: "second last split last split", + name: "second-last-split-last-split", msgs: []ipv6.Message{ newMsg(0, 0), newMsg(0, 0), @@ -120,7 +107,7 @@ func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) { wantErr: false, }, { - name: "second last no split last split overflow", + name: "second-last-no-split-last-split-overflow", msgs: []ipv6.Message{ newMsg(0, 0), newMsg(0, 0), @@ -153,10 +140,7 @@ func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) { } func Test_linuxBatchingConn_coalesceMessages(t *testing.T) { - c := &linuxBatchingConn{ - setGSOSizeInControl: setGSOSize, - getGSOSizeFromControl: getGSOSize, - } + c := &linuxBatchingConn{} withGeneveSpace := func(len, cap int) []byte { return make([]byte, len+packet.GeneveFixedHeaderLength, cap+packet.GeneveFixedHeaderLength) @@ -168,108 +152,110 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) { geneve.VNI.Set(1) cases := []struct { - name string - buffs [][]byte - geneve packet.GeneveHeader - wantLens []int + name string + buffs [][]byte + geneve packet.GeneveHeader + // Each wantLens slice corresponds to the Buffers of a single coalesced message, + // and each int is the expected length of the corresponding Buffer[i]. + wantLens [][]int wantGSO []int }{ { - name: "one message no coalesce", + name: "one-message-no-coalesce", buffs: [][]byte{ withGeneveSpace(1, 1), }, - wantLens: []int{1}, + wantLens: [][]int{{1}}, wantGSO: []int{0}, }, { - name: "one message no coalesce vni.isSet", + name: "one-message-no-coalesce-vni-isSet", buffs: [][]byte{ withGeneveSpace(1, 1), }, geneve: geneve, - wantLens: []int{1 + packet.GeneveFixedHeaderLength}, + wantLens: [][]int{{1 + packet.GeneveFixedHeaderLength}}, wantGSO: []int{0}, }, { - name: "two messages equal len coalesce", + name: "two-messages-equal-len-coalesce", buffs: [][]byte{ withGeneveSpace(1, 2), withGeneveSpace(1, 1), }, - wantLens: []int{2}, + wantLens: [][]int{{1, 1}}, wantGSO: []int{1}, }, { - name: "two messages equal len coalesce vni.isSet", + name: "two-messages-equal-len-coalesce-vni-isSet", buffs: [][]byte{ withGeneveSpace(1, 2+packet.GeneveFixedHeaderLength), withGeneveSpace(1, 1), }, geneve: geneve, - wantLens: []int{2 + (2 * packet.GeneveFixedHeaderLength)}, + wantLens: [][]int{{1 + packet.GeneveFixedHeaderLength, 1 + packet.GeneveFixedHeaderLength}}, wantGSO: []int{1 + packet.GeneveFixedHeaderLength}, }, { - name: "two messages unequal len coalesce", + name: "two-messages-unequal-len-coalesce", buffs: [][]byte{ withGeneveSpace(2, 3), withGeneveSpace(1, 1), }, - wantLens: []int{3}, + wantLens: [][]int{{2, 1}}, wantGSO: []int{2}, }, { - name: "two messages unequal len coalesce vni.isSet", + name: "two-messages-unequal-len-coalesce-vni-isSet", buffs: [][]byte{ withGeneveSpace(2, 3+packet.GeneveFixedHeaderLength), withGeneveSpace(1, 1), }, geneve: geneve, - wantLens: []int{3 + (2 * packet.GeneveFixedHeaderLength)}, + wantLens: [][]int{{2 + packet.GeneveFixedHeaderLength, 1 + packet.GeneveFixedHeaderLength}}, wantGSO: []int{2 + packet.GeneveFixedHeaderLength}, }, { - name: "three messages second unequal len coalesce", + name: "three-messages-second-unequal-len-coalesce", buffs: [][]byte{ withGeneveSpace(2, 3), withGeneveSpace(1, 1), withGeneveSpace(2, 2), }, - wantLens: []int{3, 2}, + wantLens: [][]int{{2, 1}, {2}}, wantGSO: []int{2, 0}, }, { - name: "three messages second unequal len coalesce vni.isSet", + name: "three-messages-second-unequal-len-coalesce-vni-isSet", buffs: [][]byte{ withGeneveSpace(2, 3+(2*packet.GeneveFixedHeaderLength)), withGeneveSpace(1, 1), withGeneveSpace(2, 2), }, geneve: geneve, - wantLens: []int{3 + (2 * packet.GeneveFixedHeaderLength), 2 + packet.GeneveFixedHeaderLength}, + wantLens: [][]int{{2 + packet.GeneveFixedHeaderLength, 1 + packet.GeneveFixedHeaderLength}, {2 + packet.GeneveFixedHeaderLength}}, wantGSO: []int{2 + packet.GeneveFixedHeaderLength, 0}, }, { - name: "three messages limited cap coalesce", + name: "three-messages-limited-cap-coalesce", buffs: [][]byte{ withGeneveSpace(2, 4), withGeneveSpace(2, 2), withGeneveSpace(2, 2), }, - wantLens: []int{4, 2}, - wantGSO: []int{2, 0}, + wantLens: [][]int{{2, 2, 2}}, + wantGSO: []int{2}, }, { - name: "three messages limited cap coalesce vni.isSet", + name: "three-messages-limited-cap-coalesce-vni-isSet", buffs: [][]byte{ withGeneveSpace(2, 4+packet.GeneveFixedHeaderLength), withGeneveSpace(2, 2), withGeneveSpace(2, 2), }, geneve: geneve, - wantLens: []int{4 + (2 * packet.GeneveFixedHeaderLength), 2 + packet.GeneveFixedHeaderLength}, - wantGSO: []int{2 + packet.GeneveFixedHeaderLength, 0}, + wantLens: [][]int{{2 + packet.GeneveFixedHeaderLength, 2 + packet.GeneveFixedHeaderLength, 2 + packet.GeneveFixedHeaderLength}}, + wantGSO: []int{2 + packet.GeneveFixedHeaderLength}, }, } @@ -282,7 +268,7 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) { msgs := make([]ipv6.Message, len(tt.buffs)) for i := range msgs { msgs[i].Buffers = make([][]byte, 1) - msgs[i].OOB = make([]byte, 0, 2) + msgs[i].OOB = make([]byte, controlMessageSize) } got := c.coalesceMessages(addr, tt.geneve, tt.buffs, msgs, packet.GeneveFixedHeaderLength) if got != len(tt.wantLens) { @@ -292,13 +278,28 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) { if msgs[i].Addr != addr { t.Errorf("msgs[%d].Addr != passed addr", i) } - gotLen := len(msgs[i].Buffers[0]) - if gotLen != tt.wantLens[i] { - t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + if len(msgs[i].Buffers) != len(tt.wantLens[i]) { + t.Fatalf("len(msgs[%d].Buffers) %d != %d", i, len(msgs[i].Buffers), len(tt.wantLens[i])) + } + for j := range tt.wantLens[i] { + gotLen := len(msgs[i].Buffers[j]) + if gotLen != tt.wantLens[i][j] { + t.Errorf("len(msgs[%d].Buffers[%d]) %d != %d", i, j, gotLen, tt.wantLens[i][j]) + } } - gotGSO, err := getGSOSize(msgs[i].OOB) + + // coalesceMessages calls setGSOSizeInControl, which uses a cmsg + // type of UDP_SEGMENT, and getGSOSizeInControl scans for a cmsg + // type of UDP_GRO. Therefore, we have to use the lower-level + // getDataFromControl in order to specify the cmsg type of + // interest for this test. + data, err := getDataFromControl(msgs[i].OOB, unix.SOL_UDP, unix.UDP_SEGMENT, 2) if err != nil { - t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) + t.Fatalf("msgs[%d] getDataFromControl err: %v", i, err) + } + var gotGSO int + if len(data) >= 2 { + gotGSO = int(binary.NativeEndian.Uint16(data)) } if gotGSO != tt.wantGSO[i] { t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) @@ -317,35 +318,187 @@ func TestMinReadBatchMsgsLen(t *testing.T) { } } -func Test_getGSOSizeFromControl_MultipleMessages(t *testing.T) { - // Test that getGSOSizeFromControl correctly parses UDP_GRO when it's not the first control message. - const expectedGSOSize = 1420 +func makeControlMsg(cmsgLevel, cmsgType int32, dataLen int) []byte { + msgLen := unix.CmsgSpace(dataLen) + msg := make([]byte, msgLen) + hdr2 := (*unix.Cmsghdr)(unsafe.Pointer(&msg[0])) + hdr2.Level = cmsgLevel + hdr2.Type = cmsgType + hdr2.SetLen(unix.CmsgLen(dataLen)) + return msg +} - // First message: IP_TOS - firstMsgLen := unix.CmsgSpace(1) - firstMsg := make([]byte, firstMsgLen) - hdr1 := (*unix.Cmsghdr)(unsafe.Pointer(&firstMsg[0])) - hdr1.Level = unix.SOL_IP - hdr1.Type = unix.IP_TOS - hdr1.SetLen(unix.CmsgLen(1)) - firstMsg[unix.SizeofCmsghdr] = 0 +func gsoControl(gso uint16) []byte { + msg := makeControlMsg(unix.SOL_UDP, unix.UDP_GRO, 2) + binary.NativeEndian.PutUint16(msg[unix.SizeofCmsghdr:], gso) + return msg +} - // Second message: UDP_GRO - secondMsgLen := unix.CmsgSpace(2) - secondMsg := make([]byte, secondMsgLen) - hdr2 := (*unix.Cmsghdr)(unsafe.Pointer(&secondMsg[0])) - hdr2.Level = unix.SOL_UDP - hdr2.Type = unix.UDP_GRO - hdr2.SetLen(unix.CmsgLen(2)) - binary.NativeEndian.PutUint16(secondMsg[unix.SizeofCmsghdr:], expectedGSOSize) +func rxqOverflowsControl(count uint32) []byte { + msg := makeControlMsg(unix.SOL_SOCKET, unix.SO_RXQ_OVFL, 4) + binary.NativeEndian.PutUint32(msg[unix.SizeofCmsghdr:], count) + return msg +} + +func Test_getRXQOverflowsMetric(t *testing.T) { + c := qt.New(t) + m := getRXQOverflowsMetric("") + c.Assert(m, qt.IsNil) + m = getRXQOverflowsMetric("rxq_overflows") + c.Assert(m, qt.IsNotNil) + wantM := getRXQOverflowsMetric("rxq_overflows") + c.Assert(m, qt.Equals, wantM) + uniq := getRXQOverflowsMetric("rxq_overflows_uniq") + c.Assert(m, qt.Not(qt.Equals), uniq) +} + +func Test_getRXQOverflowsFromControl(t *testing.T) { + malformedControlMsg := gsoControl(1) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&malformedControlMsg[0])) + hdr.SetLen(1) + + tests := []struct { + name string + control []byte + want uint32 + wantErr bool + }{ + { + name: "malformed", + control: malformedControlMsg, + want: 0, + wantErr: true, + }, + { + name: "gso", + control: gsoControl(1), + want: 0, + wantErr: false, + }, + { + name: "rxq-overflows", + control: rxqOverflowsControl(1), + want: 1, + wantErr: false, + }, + { + name: "multiple-cmsg-rxq-overflows-at-head", + control: append(rxqOverflowsControl(1), gsoControl(1)...), + want: 1, + wantErr: false, + }, + { + name: "multiple-cmsg-rxq-overflows-at-tail", + control: append(gsoControl(1), rxqOverflowsControl(1)...), + want: 1, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getRXQOverflowsFromControl(tt.control) + if (err != nil) != tt.wantErr { + t.Errorf("getRXQOverflowsFromControl() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("getRXQOverflowsFromControl() got = %v, want %v", got, tt.want) + } + }) + } +} - control := append(firstMsg, secondMsg...) +func Test_getGSOSizeFromControl(t *testing.T) { + malformedControlMsg := gsoControl(1) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&malformedControlMsg[0])) + hdr.SetLen(1) - gsoSize, err := getGSOSizeFromControl(control) - if err != nil { - t.Fatalf("unexpected error: %v", err) + tests := []struct { + name string + control []byte + want int + wantErr bool + }{ + { + name: "malformed", + control: malformedControlMsg, + want: 0, + wantErr: true, + }, + { + name: "gso", + control: gsoControl(1), + want: 1, + wantErr: false, + }, + { + name: "rxq-overflows", + control: rxqOverflowsControl(1), + want: 0, + wantErr: false, + }, + { + name: "multiple-cmsg-gso-at-tail", + control: append(rxqOverflowsControl(1), gsoControl(1)...), + want: 1, + wantErr: false, + }, + { + name: "multiple-cmsg-gso-at-head", + control: append(gsoControl(1), rxqOverflowsControl(1)...), + want: 1, + wantErr: false, + }, } - if gsoSize != expectedGSOSize { - t.Errorf("got GSO size %d, want %d", gsoSize, expectedGSOSize) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getGSOSizeFromControl(tt.control) + if (err != nil) != tt.wantErr { + t.Errorf("getGSOSizeFromControl() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("getGSOSizeFromControl() got = %v, want %v", got, tt.want) + } + }) } } + +func Test_linuxBatchingConn_handleRXQOverflowCounter(t *testing.T) { + c := qt.New(t) + conn := &linuxBatchingConn{ + rxqOverflowsMetric: getRXQOverflowsMetric("test_handleRXQOverflowCounter"), + } + conn.rxqOverflowsMetric.Set(0) // test count > 1 will accumulate, reset + + // n == 0 + conn.handleRXQOverflowCounter([]ipv6.Message{{}}, 0, nil) + c.Assert(conn.rxqOverflowsMetric.Value(), qt.Equals, int64(0)) + + // rxErr non-nil + conn.handleRXQOverflowCounter([]ipv6.Message{{}}, 0, io.EOF) + c.Assert(conn.rxqOverflowsMetric.Value(), qt.Equals, int64(0)) + + // nonzero counter + control := rxqOverflowsControl(1) + conn.handleRXQOverflowCounter([]ipv6.Message{{ + OOB: control, + NN: len(control), + }}, 1, nil) + c.Assert(conn.rxqOverflowsMetric.Value(), qt.Equals, int64(1)) + + // nonzero counter, no change + conn.handleRXQOverflowCounter([]ipv6.Message{{ + OOB: control, + NN: len(control), + }}, 1, nil) + c.Assert(conn.rxqOverflowsMetric.Value(), qt.Equals, int64(1)) + + // counter rollover + control = rxqOverflowsControl(0) + conn.handleRXQOverflowCounter([]ipv6.Message{{ + OOB: control, + NN: len(control), + }}, 1, nil) + c.Assert(conn.rxqOverflowsMetric.Value(), qt.Equals, int64(1+math.MaxUint32)) +} diff --git a/net/captivedetection/captivedetection_test.go b/net/captivedetection/captivedetection_test.go index 2aa660d88d0a4..6b09ca0cc9672 100644 --- a/net/captivedetection/captivedetection_test.go +++ b/net/captivedetection/captivedetection_test.go @@ -94,8 +94,7 @@ func TestCaptivePortalRequest(t *testing.T) { now := time.Now() d.clock = func() time.Time { return now } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { @@ -133,8 +132,7 @@ func TestCaptivePortalRequest(t *testing.T) { func TestAgainstDERPHandler(t *testing.T) { d := NewDetector(t.Logf) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() s := httptest.NewServer(http.HandlerFunc(derpserver.ServeNoContent)) defer s.Close() diff --git a/net/dns/direct.go b/net/dns/direct.go index ec2e42e75176f..f6f2fd6019047 100644 --- a/net/dns/direct.go +++ b/net/dns/direct.go @@ -442,7 +442,9 @@ func (m *directManager) runFileWatcher() { if !ok { return } - if err := watchFile(m.ctx, "/etc/", resolvConf, m.checkForFileTrample); err != nil { + dir := m.fs.ActualPath(filepath.Dir(resolvConf)) + file := m.fs.ActualPath(resolvConf) + if err := watchFile(m.ctx, dir, file, m.checkForFileTrample); err != nil { // This is all best effort for now, so surface warnings to users. m.logf("dns: inotify: %s", err) } @@ -597,6 +599,19 @@ type wholeFileFS interface { ReadFile(name string) ([]byte, error) Remove(name string) error Rename(oldName, newName string) error + // ActualPath returns the real filesystem path for the given absolute + // logical path. All other methods in this interface accept logical + // paths (like "/etc/resolv.conf") and translate them internally; + // ActualPath exposes that same translation for callers that need + // the real path for use outside the interface (e.g. setting up an + // inotify watch on the correct directory). + // + // For directFS with an empty prefix (production), the input is + // returned unchanged ("/etc" → "/etc"). For directFS with a test + // prefix like "/tmp/test123", the prefix is joined + // ("/etc" → "/tmp/test123/etc"). For wslFS the input is returned + // unchanged, since paths are passed through to wsl.exe as-is. + ActualPath(name string) string Stat(name string) (isRegular bool, err error) Truncate(name string) error WriteFile(name string, contents []byte, perm os.FileMode) error @@ -613,6 +628,8 @@ type directFS struct { func (fs directFS) path(name string) string { return filepath.Join(fs.prefix, name) } +func (fs directFS) ActualPath(name string) string { return fs.path(name) } + func (fs directFS) Stat(name string) (isRegular bool, err error) { fi, err := os.Stat(fs.path(name)) if err != nil { diff --git a/net/dns/direct_linux_test.go b/net/dns/direct_linux_test.go index 8199b41f3b973..c053db178505c 100644 --- a/net/dns/direct_linux_test.go +++ b/net/dns/direct_linux_test.go @@ -7,21 +7,19 @@ package dns import ( "context" - "fmt" "net/netip" "os" "path/filepath" "testing" "testing/synctest" - - "github.com/illarion/gonotify/v3" + "time" "tailscale.com/util/dnsname" "tailscale.com/util/eventbus/eventbustest" ) func TestDNSTrampleRecovery(t *testing.T) { - HookWatchFile.Set(watchFile) + t.Cleanup(HookWatchFile.SetForTest(watchFile)) synctest.Test(t, func(t *testing.T) { tmp := t.TempDir() if err := os.MkdirAll(filepath.Join(tmp, "etc"), 0700); err != nil { @@ -77,33 +75,20 @@ search ts.net ts-dns.test }) } -// watchFile is generally copied from linuxtrample, but cancels the context -// after the first call to cb() after the first trample to end the test. +// watchFile is a test implementation of the file watcher that uses a timer +// instead of inotify. Real inotify (gonotify.NewDirWatcher) creates goroutines +// that block on real syscalls, which don't work inside synctest's fake-time +// bubble. Instead, we use a one-shot timer that synctest.Wait() will advance, +// triggering a callback to check for file trampling. func watchFile(ctx context.Context, dir, filename string, cb func()) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - const events = gonotify.IN_ATTRIB | - gonotify.IN_CLOSE_WRITE | - gonotify.IN_CREATE | - gonotify.IN_DELETE | - gonotify.IN_MODIFY | - gonotify.IN_MOVE - - watcher, err := gonotify.NewDirWatcher(ctx, events, dir) - if err != nil { - return fmt.Errorf("NewDirWatcher: %w", err) - } - - for { - select { - case event := <-watcher.C: - if event.Name == filename { - cb() - cancel() - } - case <-ctx.Done(): - return ctx.Err() - } + timer := time.NewTimer(time.Millisecond) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + cb() } + <-ctx.Done() + return ctx.Err() } diff --git a/net/dns/dns_clone.go b/net/dns/dns_clone.go index 291f96ec2b51f..724e36dac86ed 100644 --- a/net/dns/dns_clone.go +++ b/net/dns/dns_clone.go @@ -6,7 +6,6 @@ package dns import ( - "maps" "net/netip" "tailscale.com/types/dnstype" @@ -34,8 +33,19 @@ func (src *Config) Clone() *Config { } if dst.Routes != nil { dst.Routes = map[dnsname.FQDN][]*dnstype.Resolver{} - for k := range src.Routes { - dst.Routes[k] = append([]*dnstype.Resolver{}, src.Routes[k]...) + for k, sv := range src.Routes { + if sv == nil { + dst.Routes[k] = nil + continue + } + dst.Routes[k] = make([]*dnstype.Resolver, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.Routes[k][i] = nil + } else { + dst.Routes[k][i] = sv[i].Clone() + } + } } } dst.SearchDomains = append(src.SearchDomains[:0:0], src.SearchDomains...) @@ -45,7 +55,7 @@ func (src *Config) Clone() *Config { dst.Hosts[k] = append([]netip.Addr{}, src.Hosts[k]...) } } - dst.SubdomainHosts = maps.Clone(src.SubdomainHosts) + dst.SubdomainHosts = src.SubdomainHosts.Clone() return dst } diff --git a/net/dns/manager.go b/net/dns/manager.go index 889c542cf1f1d..8daa13cbc6981 100644 --- a/net/dns/manager.go +++ b/net/dns/manager.go @@ -431,7 +431,14 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig defaultRoutes = append(defaultRoutes, &dnstype.Resolver{Addr: ip.String()}) } rcfg.Routes["."] = defaultRoutes - ocfg.SearchDomains = append(ocfg.SearchDomains, baseCfg.SearchDomains...) + // Append base config search domains, but only if not already present. + // This prevents duplicates when GetBaseConfig() reads back domains that + // Tailscale itself previously wrote to resolv.conf. + for _, domain := range baseCfg.SearchDomains { + if !slices.Contains(ocfg.SearchDomains, domain) { + ocfg.SearchDomains = append(ocfg.SearchDomains, domain) + } + } } return rcfg, ocfg, nil diff --git a/net/dns/manager_darwin.go b/net/dns/manager_darwin.go index bb590aa4e7c14..90686b2466a93 100644 --- a/net/dns/manager_darwin.go +++ b/net/dns/manager_darwin.go @@ -5,7 +5,10 @@ package dns import ( "bytes" + "fmt" + "io/fs" "os" + "strings" "go4.org/mem" "tailscale.com/control/controlknobs" @@ -22,15 +25,22 @@ import ( // // The health tracker, bus and the knobs may be nil and are ignored on this platform. func NewOSConfigurator(logf logger.Logf, _ *health.Tracker, _ *eventbus.Bus, _ policyclient.Client, _ *controlknobs.Knobs, ifName string) (OSConfigurator, error) { - return &darwinConfigurator{logf: logf, ifName: ifName}, nil + return &darwinConfigurator{ + logf: logf, + ifName: ifName, + resolverDir: "/etc/resolver", + resolvConfPath: "/etc/resolv.conf", + }, nil } // darwinConfigurator is the tailscaled-on-macOS DNS OS configurator that // maintains the Split DNS nameserver entries pointing MagicDNS DNS suffixes // to 100.100.100.100 using the macOS /etc/resolver/$SUFFIX files. type darwinConfigurator struct { - logf logger.Logf - ifName string + logf logger.Logf + ifName string + resolverDir string // default "/etc/resolver" + resolvConfPath string // default "/etc/resolv.conf" } func (c *darwinConfigurator) Close() error { @@ -51,10 +61,16 @@ func (c *darwinConfigurator) SetDNS(cfg OSConfig) error { buf.WriteString("\n") } - if err := os.MkdirAll("/etc/resolver", 0755); err != nil { + if err := os.MkdirAll(c.resolverDir, 0755); err != nil { return err } + root, err := os.OpenRoot(c.resolverDir) + if err != nil { + return err + } + defer root.Close() + var keep map[string]bool // Add a dummy file to /etc/resolver with a "search ..." directive if we have @@ -70,7 +86,7 @@ func (c *darwinConfigurator) SetDNS(cfg OSConfig) error { sbuf.WriteString(string(d.WithoutTrailingDot())) } sbuf.WriteString("\n") - if err := os.WriteFile("/etc/resolver/"+searchFile, sbuf.Bytes(), 0644); err != nil { + if err := root.WriteFile(searchFile, sbuf.Bytes(), 0644); err != nil { return err } } @@ -78,15 +94,34 @@ func (c *darwinConfigurator) SetDNS(cfg OSConfig) error { for _, d := range cfg.MatchDomains { fileBase := string(d.WithoutTrailingDot()) mak.Set(&keep, fileBase, true) - fullPath := "/etc/resolver/" + fileBase - if err := os.WriteFile(fullPath, buf.Bytes(), 0644); err != nil { + if !isValidResolverFileName(fileBase) { + c.logf("[unexpected] invalid resolver domain %q with slashes or colons", fileBase) + return fmt.Errorf("invalid resolver domain %q: must not contain slashes or colons", fileBase) + } + + if err := root.WriteFile(fileBase, buf.Bytes(), 0644); err != nil { return err } } return c.removeResolverFiles(func(domain string) bool { return !keep[domain] }) } +func isValidResolverFileName(name string) bool { + // Verify that the filename doesn't contain any characters that + // might cause issues when used as a filename; os.Root is a + // defense against path traversal, but prefer a nice error here + // if we can. These aren't valid for domain names anyway. + if strings.Contains(name, "/") || strings.Contains(name, "\\") { + return false + } + + if strings.Contains(name, ":") { + return false + } + return true +} + // GetBaseConfig returns the current OS DNS configuration, extracting it from /etc/resolv.conf. // We should really be using the SystemConfiguration framework to get this information, as this // is not a stable public API, and is provided mostly as a compatibility effort with Unix @@ -95,9 +130,9 @@ func (c *darwinConfigurator) SetDNS(cfg OSConfig) error { func (c *darwinConfigurator) GetBaseConfig() (OSConfig, error) { cfg := OSConfig{} - resolvConf, err := resolvconffile.ParseFile("/etc/resolv.conf") + resolvConf, err := resolvconffile.ParseFile(c.resolvConfPath) if err != nil { - c.logf("failed to parse /etc/resolv.conf: %v", err) + c.logf("failed to parse %s: %v", c.resolvConfPath, err) return cfg, ErrGetBaseConfigNotSupported } @@ -113,7 +148,7 @@ func (c *darwinConfigurator) GetBaseConfig() (OSConfig, error) { if len(cfg.Nameservers) == 0 { // Log a warning in case we couldn't find any nameservers in /etc/resolv.conf. - c.logf("no nameservers found in /etc/resolv.conf, DNS resolution might fail") + c.logf("no nameservers found in %s, DNS resolution might fail", c.resolvConfPath) } return cfg, nil @@ -124,13 +159,19 @@ const macResolverFileHeader = "# Added by tailscaled\n" // removeResolverFiles deletes all files in /etc/resolver for which the shouldDelete // func returns true. func (c *darwinConfigurator) removeResolverFiles(shouldDelete func(domain string) bool) error { - dents, err := os.ReadDir("/etc/resolver") + root, err := os.OpenRoot(c.resolverDir) if os.IsNotExist(err) { return nil } if err != nil { return err } + defer root.Close() + + dents, err := fs.ReadDir(root.FS(), ".") + if err != nil { + return err + } for _, de := range dents { if !de.Type().IsRegular() { continue @@ -139,8 +180,7 @@ func (c *darwinConfigurator) removeResolverFiles(shouldDelete func(domain string if !shouldDelete(name) { continue } - fullPath := "/etc/resolver/" + name - contents, err := os.ReadFile(fullPath) + contents, err := root.ReadFile(name) if err != nil { if os.IsNotExist(err) { // race? continue @@ -150,7 +190,7 @@ func (c *darwinConfigurator) removeResolverFiles(shouldDelete func(domain string if !mem.HasPrefix(mem.B(contents), mem.S(macResolverFileHeader)) { continue } - if err := os.Remove(fullPath); err != nil { + if err := root.Remove(name); err != nil { return err } } diff --git a/net/dns/manager_darwin_test.go b/net/dns/manager_darwin_test.go new file mode 100644 index 0000000000000..8596f9575bc64 --- /dev/null +++ b/net/dns/manager_darwin_test.go @@ -0,0 +1,182 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +import ( + "errors" + "maps" + "net/netip" + "os" + "path/filepath" + "slices" + "testing" + + "tailscale.com/types/logger" + "tailscale.com/util/dnsname" +) + +func newTestConfigurator(t *testing.T) *darwinConfigurator { + t.Helper() + dir := t.TempDir() + + resolvConf := filepath.Join(dir, "resolv.conf") + if err := os.WriteFile(resolvConf, []byte("nameserver 8.8.8.8\n"), 0644); err != nil { + t.Fatal(err) + } + + resolverDir := filepath.Join(dir, "resolvers") + if err := os.Mkdir(resolverDir, 0755); err != nil { + t.Fatal(err) + } + + return &darwinConfigurator{ + logf: logger.Discard, + ifName: "utun99", + resolverDir: resolverDir, + resolvConfPath: resolvConf, + } +} + +func TestSetDNS(t *testing.T) { + c := newTestConfigurator(t) + + tests := []struct { + name string + cfg OSConfig + fileContents map[string]string // path -> expected file contents + }{ + { + name: "basic", + cfg: OSConfig{ + Nameservers: []netip.Addr{netip.MustParseAddr("100.100.100.100")}, + MatchDomains: []dnsname.FQDN{"example.com.", "ts.net."}, + }, + fileContents: map[string]string{ + "example.com": macResolverFileHeader + "nameserver 100.100.100.100\n", + "ts.net": macResolverFileHeader + "nameserver 100.100.100.100\n", + }, + }, + { + name: "SearchDomains", + cfg: OSConfig{ + Nameservers: []netip.Addr{netip.MustParseAddr("100.100.100.100")}, + SearchDomains: []dnsname.FQDN{"tail1234.ts.net."}, + MatchDomains: []dnsname.FQDN{"ts.net."}, + }, + fileContents: map[string]string{ + "ts.net": macResolverFileHeader + "nameserver 100.100.100.100\n", + "search.tailscale": macResolverFileHeader + "search tail1234.ts.net\n", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := c.SetDNS(tt.cfg); err != nil { + t.Fatalf("SetDNS failed: %v", err) + } + + // We want only the expected files in the resolverDir, + // and nothing else. + files, err := os.ReadDir(c.resolverDir) + if err != nil { + t.Fatalf("reading resolver directory: %v", err) + } + + var fileNames []string + for _, f := range files { + fileNames = append(fileNames, f.Name()) + } + + if len(files) != len(tt.fileContents) { + t.Fatalf("expected %d resolver files, got %d\ngot: %v\nwant: %v", + len(tt.fileContents), len(files), + fileNames, slices.Collect(maps.Keys(tt.fileContents)), + ) + } + + // Check each file's contents. + for domain, expected := range tt.fileContents { + path := filepath.Join(c.resolverDir, domain) + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("reading resolver file %q: %v", domain, err) + } + if string(data) != expected { + t.Errorf("resolver file %q contents mismatch:\ngot: %q\nwant: %q", domain, string(data), expected) + } + } + }) + } +} + +func TestSetDNS_PathTraversal(t *testing.T) { + c := newTestConfigurator(t) + + // Use a simple path traversal that tries to escape the resolver + // directory. With the previously-vulnerable code (os.WriteFile with string + // concatenation), this writes to the parent directory. With the + // fix (os.Root), this is rejected. + traversals := []dnsname.FQDN{ + "../evil.", + "../../evil.", + "sub/../../evil.", + } + + for _, traversal := range traversals { + cfg := OSConfig{ + Nameservers: []netip.Addr{netip.MustParseAddr("100.100.100.100")}, + MatchDomains: []dnsname.FQDN{traversal}, + } + + if err := c.SetDNS(cfg); err == nil { + t.Errorf("SetDNS with MatchDomain %q should have failed, but succeeded", traversal) + } + } + + // Verify no file named "evil" was written in the parent of resolverDir. + parent := filepath.Dir(c.resolverDir) + if fileExists(filepath.Join(parent, "evil")) { + t.Fatal("file 'evil' was written to parent directory via path traversal") + } +} + +func TestRemoveResolverFiles(t *testing.T) { + c := newTestConfigurator(t) + + // Write a tailscale-managed file. + managed := filepath.Join(c.resolverDir, "ts.net") + if err := os.WriteFile(managed, []byte(macResolverFileHeader+"nameserver 100.100.100.100\n"), 0644); err != nil { + t.Fatal(err) + } + + // Write a non-tailscale file that should be left alone. + unmanaged := filepath.Join(c.resolverDir, "other.conf") + if err := os.WriteFile(unmanaged, []byte("# not ours\nnameserver 8.8.8.8\n"), 0644); err != nil { + t.Fatal(err) + } + + // Remove all resolver files and verify that only the managed one is removed. + if err := c.removeResolverFiles(func(domain string) bool { return true }); err != nil { + t.Fatal(err) + } + + if fileExists(managed) { + t.Error("managed file should have been removed") + } + if !fileExists(unmanaged) { + t.Error("unmanaged file should still exist") + } +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + if errors.Is(err, os.ErrNotExist) { + return false + } else if err == nil { + return true + } + + panic("unexpected error checking file existence: " + err.Error()) +} diff --git a/net/dns/manager_linux.go b/net/dns/manager_linux.go index e68b2e7f9e266..392b64ba989ca 100644 --- a/net/dns/manager_linux.go +++ b/net/dns/manager_linux.go @@ -380,7 +380,7 @@ func isLibnssResolveUsed(env newOSConfigEnv) error { if err != nil { return fmt.Errorf("reading /etc/resolv.conf: %w", err) } - for _, line := range strings.Split(string(bs), "\n") { + for line := range strings.SplitSeq(string(bs), "\n") { fields := strings.Fields(line) if len(fields) < 2 || fields[0] != "hosts:" { continue diff --git a/net/dns/manager_linux_test.go b/net/dns/manager_linux_test.go index d48fe23e70a8b..c3c99307ad62d 100644 --- a/net/dns/manager_linux_test.go +++ b/net/dns/manager_linux_test.go @@ -22,7 +22,7 @@ func TestLinuxDNSMode(t *testing.T) { want string }{ { - name: "no_obvious_resolv.conf_owner", + name: "no_obvious_resolvconf_owner", env: env(resolvDotConf("nameserver 10.0.0.1")), wantLog: "dns: [rc=unknown ret=direct]", want: "direct", @@ -153,7 +153,7 @@ func TestLinuxDNSMode(t *testing.T) { // alleged that it was managed by systemd-resolved, but it // was actually a completely static config file pointing // elsewhere. - name: "allegedly_resolved_but_not_in_resolv.conf", + name: "allegedly_resolved_but_not_in_resolvconf", env: env(resolvDotConf("# Managed by systemd-resolved", "nameserver 10.0.0.1")), wantLog: "dns: resolvedIsActuallyResolver error: resolv.conf doesn't point to systemd-resolved; points to [10.0.0.1]\n" + "dns: [rc=resolved resolved=not-in-use ret=direct]", @@ -163,7 +163,7 @@ func TestLinuxDNSMode(t *testing.T) { // We used to incorrectly decide that resolved wasn't in // charge when handed this (admittedly weird and bugged) // resolv.conf. - name: "resolved_with_duplicates_in_resolv.conf", + name: "resolved_with_duplicates_in_resolvconf", env: env( resolvDotConf( "# Managed by systemd-resolved", @@ -316,6 +316,7 @@ func (m memFS) Stat(name string) (isRegular bool, err error) { func (m memFS) Chmod(name string, mode os.FileMode) error { panic("TODO") } func (m memFS) Rename(oldName, newName string) error { panic("TODO") } func (m memFS) Remove(name string) error { panic("TODO") } +func (m memFS) ActualPath(name string) string { return name } func (m memFS) ReadFile(name string) ([]byte, error) { v, ok := m[name] if !ok { diff --git a/net/dns/openresolv.go b/net/dns/openresolv.go index c3aaf3a6948c8..2a4ed174e3f09 100644 --- a/net/dns/openresolv.go +++ b/net/dns/openresolv.go @@ -82,7 +82,7 @@ func (m openresolvManager) GetBaseConfig() (OSConfig, error) { // Remove the "tailscale" snippet from the list. args := []string{"-l"} - for _, f := range strings.Split(strings.TrimSpace(string(bs)), " ") { + for f := range strings.SplitSeq(strings.TrimSpace(string(bs)), " ") { if f == "tailscale" { continue } diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 6fec32d6a2685..3f586b60f381a 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -13,6 +13,7 @@ import ( "errors" "fmt" "io" + "maps" "net" "net/http" "net/netip" @@ -41,8 +42,10 @@ import ( "tailscale.com/types/dnstype" "tailscale.com/types/logger" "tailscale.com/types/nettype" + "tailscale.com/types/views" "tailscale.com/util/cloudenv" "tailscale.com/util/dnsname" + "tailscale.com/util/mak" "tailscale.com/util/race" "tailscale.com/version" ) @@ -324,6 +327,19 @@ type forwarder struct { // resolver lookup. cloudHostFallback []resolverAndDelay + // schemes are the collection of registered URI scheme names that + // dynamically decide which resolver to use at the time of each query. The + // key is the scheme (the portion before the first `:`) and the value is a + // handler that determines where the current query should be sent. + // Use schemeCacheLocked() to get the current contents that can continue to + // be accessed once mu is released. This allows the (much more common) + // resolver code path to avoid repeated locking and unlocking. + // When modified, call invalidateSchemeCacheLocked() before unlocking mu. + schemes map[string]CustomSchemeHandler + // schemeCache is an immutable copy of schemes. Do not read directly, + // use schemeCacheLocked() which will regenerate its contents as needed. + schemeCache views.Map[string, CustomSchemeHandler] + // acceptDNS tracks the CorpDNS pref (--accept-dns) // This lets us skip health warnings if the forwarder receives inbound // queries directly - but we didn't configure it with any upstream resolvers. @@ -727,8 +743,7 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe } // If we got a truncated UDP response, return that instead of an error. - var trErr truncatedResponseError - if errors.As(err, &trErr) { + if trErr, ok := errors.AsType[truncatedResponseError](err); ok { return trErr.res, nil } return nil, err @@ -740,6 +755,27 @@ type truncatedResponseError struct { func (tr truncatedResponseError) Error() string { return "response truncated" } +// rcodeResponseError is returned when an upstream DNS server responds with an +// rcode that is treated as a soft error (currently REFUSED and SERVFAIL). The +// response bytes are preserved so they can be returned to the client rather +// than synthesizing a new response. +type rcodeResponseError struct { + rcode dns.RCode + res []byte +} + +func (r rcodeResponseError) Error() string { return r.Unwrap().Error() } +func (r rcodeResponseError) Unwrap() error { + switch r.rcode { + case dns.RCodeRefused: + return errRefused + case dns.RCodeServerFailure: + return errServerFailure + } + return nil +} + +var errRefused = errors.New("response code indicates refusal") var errServerFailure = errors.New("response code indicates server issue") var errTxIDMismatch = errors.New("txid doesn't match") @@ -813,10 +849,16 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn rcode := getRCode(out) // don't forward transient errors back to the client when the server fails - if rcode == dns.RCodeServerFailure { - f.logf("recv: response code indicating server failure: %d", rcode) + switch rcode { + case dns.RCodeServerFailure: + f.logf("sendUDP: response code indicating server failure: %d", rcode) metricDNSFwdUDPErrorServer.Add(1) - return nil, errServerFailure + return nil, rcodeResponseError{dns.RCodeServerFailure, out} + case dns.RCodeRefused: + // treat REFUSED as a soft error so other resolvers in the race can respond + f.logf("sendUDP: response code indicating refusal: %d", rcode) + metricDNSFwdUDPErrorRefused.Add(1) + return nil, rcodeResponseError{dns.RCodeRefused, out} } // Set the truncated bit if buffer was truncated during read and the flag isn't already set @@ -952,10 +994,16 @@ func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAn rcode := getRCode(out) // don't forward transient errors back to the client when the server fails - if rcode == dns.RCodeServerFailure { + switch rcode { + case dns.RCodeServerFailure: f.logf("sendTCP: response code indicating server failure: %d", rcode) metricDNSFwdTCPErrorServer.Add(1) - return nil, errServerFailure + return nil, rcodeResponseError{dns.RCodeServerFailure, out} + case dns.RCodeRefused: + // treat REFUSED as a soft error so other resolvers in the race can respond + f.logf("sendTCP: response code indicating refusal: %d", rcode) + metricDNSFwdTCPErrorRefused.Add(1) + return nil, rcodeResponseError{dns.RCodeRefused, out} } // TODO(andrew): do we need to do this? @@ -964,15 +1012,66 @@ func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAn return out, nil } +// applySchemes resolves any custom-scheme entries in rrs using the provided +// scheme handlers, returning the resulting slice. Entries whose handler returns +// an error or empty string are dropped. Entries with no registered scheme pass +// through unchanged. If schemes is nil, rrs is returned as-is. +func applySchemes(logf logger.Logf, rrs []resolverAndDelay, schemes views.Map[string, CustomSchemeHandler]) []resolverAndDelay { + if schemes.IsNil() { + return rrs + } + var result []resolverAndDelay + for i, rr := range rrs { + scheme, _, hasColon := strings.Cut(rr.name.Addr, ":") + handler, isCustom := schemes.GetOk(scheme) + if !hasColon || !isCustom { + if result != nil { + result = append(result, rr) + } + continue + } + // Avoid making a results slice in the common case where there + // are no custom scheme resolvers. + if result == nil { + result = make([]resolverAndDelay, i, len(rrs)) + copy(result, rrs) + } + newAddr, err := handler(rr.name.Addr) + if err != nil { + logf("error from custom scheme handler, skipping resolver : %v", err) + } + if err != nil || newAddr == "" { + continue + } + newResolver := *rr.name + newResolver.Addr = newAddr + result = append(result, resolverAndDelay{name: &newResolver, startDelay: rr.startDelay}) + } + // If we didn't have any custom schemes, return the original rrs. + if result == nil { + return rrs + } + return result +} + // resolvers returns the resolvers to use for domain. func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay { f.mu.Lock() routes := f.routes cloudHostFallback := f.cloudHostFallback + schemes := f.schemeCacheLocked() f.mu.Unlock() + for _, route := range routes { - if route.Suffix == "." || route.Suffix.Contains(domain) { - return route.Resolvers + if route.Suffix != "." && !route.Suffix.Contains(domain) { + continue + } + resolved := applySchemes(f.logf, route.Resolvers, schemes) + // If scheme resolution filtered out all resolvers from a non-empty + // route, fall through to the next matching route. If the resolvers + // were configured to be empty allow resolved to be empty. + if len(resolved) > 0 || len(route.Resolvers) == 0 { + return resolved } } return cloudHostFallback // or nil if no fallback @@ -989,6 +1088,39 @@ func (f *forwarder) GetUpstreamResolvers(name dnsname.FQDN) []*dnstype.Resolver return upstreamResolvers } +// RegisterCustomScheme adds a [CustomSchemeHandler] that is called to provide +// an updated address when a [dnstype.Resolver.Addr] uses that scheme. +func (f *forwarder) RegisterCustomScheme(scheme string, h CustomSchemeHandler) error { + f.mu.Lock() + defer f.mu.Unlock() + if _, ok := f.schemes[scheme]; ok { + return fmt.Errorf("scheme %q already registered", scheme) + } + f.invalidateSchemeCacheLocked() + mak.Set(&f.schemes, scheme, h) + return nil +} + +// invalidateSchemeCacheLocked clears f.schemeCache so that it will be rebuilt +// on the next call to f.schemeCacheLocked(). +func (f *forwarder) invalidateSchemeCacheLocked() { + f.schemeCache = views.Map[string, CustomSchemeHandler]{} +} + +// schemeCacheLocked returns an immutable copy of f.schemes that can be used +// after mu is unlocked. +func (f *forwarder) schemeCacheLocked() views.Map[string, CustomSchemeHandler] { + if !f.schemeCache.IsNil() { + return f.schemeCache + } + if f.schemes == nil { + return f.schemeCache // returns a nil view + } + // Regenerate the cache + f.schemeCache = views.MapOf(maps.Clone(f.schemes)) + return f.schemeCache +} + // forwardQuery is information and state about a forwarded DNS query that's // being sent to 1 or more upstreams. // @@ -1129,6 +1261,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo var firstErr error var numErr int + var sawNonRefused bool for { select { case v := <-resc: @@ -1148,32 +1281,56 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo if firstErr == nil { firstErr = err } + if !errors.Is(err, errRefused) { + sawNonRefused = true + } numErr++ if numErr == len(resolvers) { - if errors.Is(firstErr, errServerFailure) { - res, err := servfailResponse(query) - if err != nil { - f.logf("building servfail response: %v", err) + var res packet + if sawNonRefused { + // At least one server failed with SERVFAIL or a transport error + // (e.g. network failure, TxID mismatch, unsupported resolver type). + // All such errors map to SERVFAIL at the client level. + // Prefer returning the upstream SERVFAIL bytes from firstErr if + // available; otherwise synthesize a SERVFAIL response. Note the + // rcode guard: firstErr may be a REFUSED rcodeResponseError if it + // arrived before the SERVFAIL that set sawNonRefused. + if rcodeErr, ok := errors.AsType[rcodeResponseError](firstErr); ok && rcodeErr.rcode == dns.RCodeServerFailure { + res = packet{rcodeErr.res, query.family, query.addr} + } else { + r, err := servfailResponse(query) + if err != nil { + f.logf("building servfail response: %v", err) + return firstErr + } + res = r + } + } else { + // !sawNonRefused means every error was an rcodeResponseError with rcode REFUSED, + // so firstErr is guaranteed to wrap one. + rcodeErr, ok := errors.AsType[rcodeResponseError](firstErr) + if !ok { + f.logf("unexpected: all errors were REFUSED but firstErr is not rcodeResponseError: %v", firstErr) return firstErr } - - select { - case <-ctx.Done(): - metricDNSFwdErrorContext.Add(1) - metricDNSFwdErrorContextGotError.Add(1) - var resolverAddrs []string - for _, rr := range resolvers { - resolverAddrs = append(resolverAddrs, rr.name.Addr) - } - if f.acceptDNS { - f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")}) - } - case responseChan <- res: - if f.verboseFwd { - f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr) - } - return nil + res = packet{rcodeErr.res, query.family, query.addr} + } + select { + case <-ctx.Done(): + metricDNSFwdErrorContext.Add(1) + metricDNSFwdErrorContextGotError.Add(1) + var resolverAddrs []string + for _, rr := range resolvers { + resolverAddrs = append(resolverAddrs, rr.name.Addr) + } + if f.acceptDNS { + f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")}) + } + case responseChan <- res: + if f.verboseFwd { + f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr) } + return nil } return firstErr } diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index 6fd186c25a61c..ebe4041a69820 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -27,6 +27,7 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/tstest" "tailscale.com/types/dnstype" + "tailscale.com/util/dnsname" "tailscale.com/util/eventbus/eventbustest" ) @@ -328,7 +329,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on udpLn *net.UDPConn err error ) - for try := 0; try < tries; try++ { + for range tries { if tcpLn != nil { tcpLn.Close() tcpLn = nil @@ -392,9 +393,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on var wg sync.WaitGroup if opts == nil || !opts.SkipTCP { - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { for { conn, err := tcpLn.Accept() if err != nil { @@ -402,7 +401,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on } go handleConn(conn) } - }() + }) } handleUDP := func(addr netip.AddrPort, req []byte) { @@ -413,9 +412,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on } if opts == nil || !opts.SkipUDP { - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { for { buf := make([]byte, 65535) n, addr, err := udpLn.ReadFromUDPAddrPort(buf) @@ -425,7 +422,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on buf = buf[:n] go handleUDP(addr, buf) } - }() + }) } tb.Cleanup(func() { @@ -684,7 +681,7 @@ func makeResponseOfSize(tb testing.TB, domain string, targetSize int, includeOPT var response []byte var err error - for attempt := 0; attempt < 10; attempt++ { + for range 10 { testBuilder := dns.NewBuilder(nil, dns.Header{ Response: true, Authoritative: true, @@ -1131,7 +1128,7 @@ func TestForwarderWithManyResolvers(t *testing.T) { }, }, { - name: "ServFail+Success", + name: "ServFail-and-Success", responses: [][]byte{ // All upstream servers fail except for one. makeTestResponse(t, domain, dns.RCodeServerFailure), makeTestResponse(t, domain, dns.RCodeServerFailure), @@ -1154,7 +1151,7 @@ func TestForwarderWithManyResolvers(t *testing.T) { }, }, { - name: "NXDomain+Success", + name: "NXDomain-and-Success", responses: [][]byte{ // All upstream servers returned NXDOMAIN except for one. makeTestResponse(t, domain, dns.RCodeNameError), makeTestResponse(t, domain, dns.RCodeNameError), @@ -1166,8 +1163,19 @@ func TestForwarderWithManyResolvers(t *testing.T) { }, }, { - name: "Refused", - responses: [][]byte{ // All upstream servers return different failures. + name: "AllRefused", + responses: [][]byte{ // All upstream servers return REFUSED. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + wantResponses: [][]byte{ // When all refuse, return REFUSED to the client. + makeTestResponse(t, domain, dns.RCodeRefused), + }, + }, + { + name: "Refused-and-Success", + responses: [][]byte{ // Some upstream servers refuse, but one succeeds. makeTestResponse(t, domain, dns.RCodeRefused), makeTestResponse(t, domain, dns.RCodeRefused), makeTestResponse(t, domain, dns.RCodeRefused), @@ -1175,21 +1183,30 @@ func TestForwarderWithManyResolvers(t *testing.T) { makeTestResponse(t, domain, dns.RCodeRefused), makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), }, - wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded. - makeTestResponse(t, domain, dns.RCodeRefused), + wantResponses: [][]byte{ // Refused is treated as a soft error; the Success response should win. makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), }, }, + { + name: "Refused-and-ServFail", + responses: [][]byte{ // Some servers refuse, at least one fails. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + wantResponses: [][]byte{ // Any non-REFUSED failure triggers SERVFAIL regardless of arrival order. + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + }, { name: "MixFail", - responses: [][]byte{ // All upstream servers return different failures. + responses: [][]byte{ // Upstream servers return different failures. makeTestResponse(t, domain, dns.RCodeServerFailure), makeTestResponse(t, domain, dns.RCodeNameError), makeTestResponse(t, domain, dns.RCodeRefused), }, - wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded. + wantResponses: [][]byte{ // SERVFAIL and REFUSED are soft errors; NXDOMAIN wins. makeTestResponse(t, domain, dns.RCodeNameError), - makeTestResponse(t, domain, dns.RCodeRefused), }, }, } @@ -1301,3 +1318,210 @@ func TestForwarderVerboseLogs(t *testing.T) { t.Errorf("expected forwarding log, got:\n%s", logStr) } } + +// TestForwarderHealthOnContextExpiry verifies that when all resolvers fail and +// the context expires before the response can be sent, the health tracker is +// set unhealthy if and only if acceptDNS is true. +func TestForwarderHealthOnContextExpiry(t *testing.T) { + const domain = "health-test.example.com." + + tests := []struct { + name string + acceptDNS bool + wantUnhealthy bool + }{ + {"acceptDNS=true", true, true}, + {"acceptDNS=false", false, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := makeTestRequest(t, domain, dns.TypeA, 0) + logf := tstest.WhileTestRunningLogger(t) + bus := eventbustest.NewBus(t) + netMon, err := netmon.New(bus, logf) + if err != nil { + t.Fatal(err) + } + + var dialer tsdial.Dialer + dialer.SetNetMon(netMon) + dialer.SetBus(bus) + + ht := health.NewTracker(bus) + fwd := newForwarder(logf, netMon, nil, &dialer, ht, nil) + fwd.acceptDNS = tt.acceptDNS + + port1 := runDNSServer(t, nil, makeTestResponse(t, domain, dns.RCodeServerFailure), func(bool, []byte) {}) + port2 := runDNSServer(t, nil, makeTestResponse(t, domain, dns.RCodeServerFailure), func(bool, []byte) {}) + + resolvers := []resolverAndDelay{ + {name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port1)}}, + {name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port2)}}, + } + + rpkt := packet{ + bs: request, + family: "udp", + addr: netip.MustParseAddrPort("127.0.0.1:12345"), + } + + // Use an unbuffered responseChan so the send blocks, forcing the + // ctx.Done path and the SetUnhealthy call. + responseChan := make(chan packet) + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel after DNS servers have had time to respond and their errors + // collected, leaving forwardWithDestChan blocked on responseChan. + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + fwd.forwardWithDestChan(ctx, rpkt, responseChan, resolvers...) + + if got := ht.IsUnhealthy(dnsForwarderFailing); got != tt.wantUnhealthy { + t.Errorf("IsUnhealthy = %v, want %v", got, tt.wantUnhealthy) + } + }) + } +} + +func TestResolversCustomScheme(t *testing.T) { + t.Parallel() + tests := []struct { + name string + domain dnsname.FQDN + schemes map[string]CustomSchemeHandler + routes map[dnsname.FQDN][]*dnstype.Resolver + wantAddrs []string + }{ + { + name: "no-custom-scheme", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{}, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "192.168.1.1:53"}, + {Addr: "192.168.1.2:53"}, + }, + }, + wantAddrs: []string{"192.168.1.1:53", "192.168.1.2:53"}, + }, + { + name: "single-custom-scheme", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(string) (string, error) { return "1.2.3.4:53", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": {{Addr: "myscheme:customKey"}}, + }, + wantAddrs: []string{"1.2.3.4:53"}, + }, + { + name: "with-other-resolvers", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(key string) (string, error) { return "1.2.3.4:53", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "192.168.1.1:53"}, + {Addr: "myscheme:customKey"}, + {Addr: "192.168.1.2:53"}, + }, + }, + wantAddrs: []string{"192.168.1.1:53", "1.2.3.4:53", "192.168.1.2:53"}, + }, + { + name: "multiple-custom-schemes", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "schemeOne": func(string) (string, error) { return "1.2.3.4:53", nil }, + "schemeTwo": func(string) (string, error) { return "5.6.7.8:53", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "schemeOne:customKey"}, + {Addr: "schemeTwo:customKey"}, + }, + }, + wantAddrs: []string{"1.2.3.4:53", "5.6.7.8:53"}, + }, + { + name: "empty-string-means-no-resolver", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(string) (string, error) { return "", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "192.168.1.1:53"}, + {Addr: "myscheme:customKey"}, + }, + }, + wantAddrs: []string{"192.168.1.1:53"}, + }, + { + name: "error-means-no-resolver", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(string) (string, error) { return "", fmt.Errorf("handler error") }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "192.168.1.1:53"}, + {Addr: "myscheme:customKey"}, + }, + }, + wantAddrs: []string{"192.168.1.1:53"}, + }, + { + // If the best-matching route yields no resolvers after scheme + // resolution, fall through to the next matching route. + name: "empty-scheme-result-falls-through-to-next-matching-route", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(string) (string, error) { return "", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": {{Addr: "myscheme:customKey"}}, + ".": {{Addr: "192.168.1.1:53"}}, + }, + wantAddrs: []string{"192.168.1.1:53"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logf := tstest.WhileTestRunningLogger(t) + bus := eventbustest.NewBus(t) + netMon, err := netmon.New(bus, logf) + if err != nil { + t.Fatal(err) + } + var dialer tsdial.Dialer + dialer.SetNetMon(netMon) + dialer.SetBus(bus) + + fwd := newForwarder(logf, netMon, nil, &dialer, health.NewTracker(bus), nil) + for scheme, handler := range tt.schemes { + if err := fwd.RegisterCustomScheme(scheme, handler); err != nil { + t.Fatal(err) + } + } + + fwd.setRoutes(tt.routes, false) + + got := fwd.resolvers(tt.domain) + var gotAddrs []string + for _, r := range got { + gotAddrs = append(gotAddrs, r.name.Addr) + } + if !slices.Equal(gotAddrs, tt.wantAddrs) { + t.Errorf("got %v, want %v", gotAddrs, tt.wantAddrs) + } + }) + } +} diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index d0601de7bfe25..4b2db5705ea46 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -16,7 +16,7 @@ import ( "net/netip" "os" "runtime" - "sort" + "slices" "strconv" "strings" "sync" @@ -172,7 +172,7 @@ func WriteRoutes(w *bufio.Writer, routes map[dnsname.FQDN][]*dnstype.Resolver) { } kk = append(kk, k) } - sort.Slice(kk, func(i, j int) bool { return kk[i] < kk[j] }) + slices.Sort(kk) w.WriteByte('{') for i, k := range kk { if i > 0 { @@ -293,6 +293,18 @@ func (r *Resolver) SetConfig(cfg Config) error { return nil } +// CustomSchemeHandler takes a URI (retrieved from [dnstype.Resolver.Addr]) and +// returns an updated URI to use for the current query. The result is only valid +// for right now and may change over time. +type CustomSchemeHandler func(addr string) (newAddr string, err error) + +// RegisterCustomScheme adds a [CustomSchemaHandler] that is called to provide +// an updated address to the forwarder when a [dnstype.Resolver.Addr] uses that +// scheme. +func (r *Resolver) RegisterCustomScheme(scheme string, h CustomSchemeHandler) error { + return r.forwarder.RegisterCustomScheme(scheme, h) +} + // Close shuts down the resolver and ensures poll goroutines have exited. // The Resolver cannot be used again after Close is called. func (r *Resolver) Close() { @@ -1402,21 +1414,23 @@ var ( metricDNSFwdErrorType = clientmetric.NewCounter("dns_query_fwd_error_type") metricDNSFwdTruncated = clientmetric.NewCounter("dns_query_fwd_truncated") - metricDNSFwdUDP = clientmetric.NewCounter("dns_query_fwd_udp") // on entry - metricDNSFwdUDPWrote = clientmetric.NewCounter("dns_query_fwd_udp_wrote") // sent UDP packet - metricDNSFwdUDPErrorWrite = clientmetric.NewCounter("dns_query_fwd_udp_error_write") - metricDNSFwdUDPErrorServer = clientmetric.NewCounter("dns_query_fwd_udp_error_server") - metricDNSFwdUDPErrorTxID = clientmetric.NewCounter("dns_query_fwd_udp_error_txid") - metricDNSFwdUDPErrorRead = clientmetric.NewCounter("dns_query_fwd_udp_error_read") - metricDNSFwdUDPSuccess = clientmetric.NewCounter("dns_query_fwd_udp_success") - - metricDNSFwdTCP = clientmetric.NewCounter("dns_query_fwd_tcp") // on entry - metricDNSFwdTCPWrote = clientmetric.NewCounter("dns_query_fwd_tcp_wrote") // sent TCP packet - metricDNSFwdTCPErrorWrite = clientmetric.NewCounter("dns_query_fwd_tcp_error_write") - metricDNSFwdTCPErrorServer = clientmetric.NewCounter("dns_query_fwd_tcp_error_server") - metricDNSFwdTCPErrorTxID = clientmetric.NewCounter("dns_query_fwd_tcp_error_txid") - metricDNSFwdTCPErrorRead = clientmetric.NewCounter("dns_query_fwd_tcp_error_read") - metricDNSFwdTCPSuccess = clientmetric.NewCounter("dns_query_fwd_tcp_success") + metricDNSFwdUDP = clientmetric.NewCounter("dns_query_fwd_udp") // on entry + metricDNSFwdUDPWrote = clientmetric.NewCounter("dns_query_fwd_udp_wrote") // sent UDP packet + metricDNSFwdUDPErrorWrite = clientmetric.NewCounter("dns_query_fwd_udp_error_write") + metricDNSFwdUDPErrorServer = clientmetric.NewCounter("dns_query_fwd_udp_error_server") + metricDNSFwdUDPErrorRefused = clientmetric.NewCounter("dns_query_fwd_udp_error_refused") + metricDNSFwdUDPErrorTxID = clientmetric.NewCounter("dns_query_fwd_udp_error_txid") + metricDNSFwdUDPErrorRead = clientmetric.NewCounter("dns_query_fwd_udp_error_read") + metricDNSFwdUDPSuccess = clientmetric.NewCounter("dns_query_fwd_udp_success") + + metricDNSFwdTCP = clientmetric.NewCounter("dns_query_fwd_tcp") // on entry + metricDNSFwdTCPWrote = clientmetric.NewCounter("dns_query_fwd_tcp_wrote") // sent TCP packet + metricDNSFwdTCPErrorWrite = clientmetric.NewCounter("dns_query_fwd_tcp_error_write") + metricDNSFwdTCPErrorServer = clientmetric.NewCounter("dns_query_fwd_tcp_error_server") + metricDNSFwdTCPErrorRefused = clientmetric.NewCounter("dns_query_fwd_tcp_error_refused") + metricDNSFwdTCPErrorTxID = clientmetric.NewCounter("dns_query_fwd_tcp_error_txid") + metricDNSFwdTCPErrorRead = clientmetric.NewCounter("dns_query_fwd_tcp_error_read") + metricDNSFwdTCPSuccess = clientmetric.NewCounter("dns_query_fwd_tcp_success") metricDNSFwdDoH = clientmetric.NewCounter("dns_query_fwd_doh") metricDNSFwdDoHErrorStatus = clientmetric.NewCounter("dns_query_fwd_doh_error_status") diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index 8ee22dd1384c0..381ceedb4e194 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -1557,15 +1557,13 @@ func TestServfail(t *testing.T) { t.Fatalf("err = %v, want nil", err) } + // The upstream server's SERVFAIL bytes are returned directly. wantPkt := []byte{ 0x00, 0x00, // transaction id: 0 - 0x84, 0x02, // flags: response, authoritative, error: servfail - 0x00, 0x01, // one question + 0x00, 0x02, // flags: error: servfail + 0x00, 0x00, // no questions (upstream sent a minimal response) 0x00, 0x00, // no answers 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs - // Question: - 0x04, 0x74, 0x65, 0x73, 0x74, 0x04, 0x73, 0x69, 0x74, 0x65, 0x00, // name - 0x00, 0x01, 0x00, 0x01, // type A, class IN } if !bytes.Equal(pkt, wantPkt) { diff --git a/net/dns/wsl_windows.go b/net/dns/wsl_windows.go index c2400746b8a2d..b0e62170b2e04 100644 --- a/net/dns/wsl_windows.go +++ b/net/dns/wsl_windows.go @@ -148,6 +148,8 @@ type wslFS struct { distro string } +func (fs wslFS) ActualPath(name string) string { return name } + func (fs wslFS) Stat(name string) (isRegular bool, err error) { err = wslRun(fs.cmd("test", "-f", name)) if ee, _ := err.(*exec.ExitError); ee != nil { @@ -172,8 +174,7 @@ func (fs wslFS) Truncate(name string) error { return fs.WriteFile(name, nil, 064 func (fs wslFS) ReadFile(name string) ([]byte, error) { b, err := wslCombinedOutput(fs.cmd("cat", "--", name)) - var ee *exec.ExitError - if errors.As(err, &ee) && ee.ExitCode() == 1 { + if ee, ok := errors.AsType[*exec.ExitError](err); ok && ee.ExitCode() == 1 { return nil, os.ErrNotExist } return b, err diff --git a/net/ipset/ipset.go b/net/ipset/ipset.go index 92cec9d0be854..140814d96676c 100644 --- a/net/ipset/ipset.go +++ b/net/ipset/ipset.go @@ -20,13 +20,6 @@ func FalseContainsIPFunc() func(ip netip.Addr) bool { func emptySet(ip netip.Addr) bool { return false } -func bartLookup(t *bart.Table[struct{}]) func(netip.Addr) bool { - return func(ip netip.Addr) bool { - _, ok := t.Lookup(ip) - return ok - } -} - func prefixContainsLoop(addrs []netip.Prefix) func(netip.Addr) bool { return func(ip netip.Addr) bool { for _, p := range addrs { @@ -81,11 +74,11 @@ func NewContainsIPFunc(addrs views.Slice[netip.Prefix]) func(ip netip.Addr) bool } pathForTest("bart") // Built a bart table. - t := &bart.Table[struct{}]{} + t := &bart.Lite{} for _, p := range addrs.All() { - t.Insert(p, struct{}{}) + t.Insert(p) } - return bartLookup(t) + return t.Contains } // Fast paths for 1 and 2 IPs: if addrs.Len() == 1 { diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index ebcdc4eaca4e3..a64c358c5c09f 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -545,7 +545,7 @@ func makeProbePlanInitial(dm *tailcfg.DERPMap, ifState *netmon.State) (plan prob var p4 []probe var p6 []probe - for try := 0; try < 3; try++ { + for try := range 3 { n := reg.Nodes[try%len(reg.Nodes)] delay := time.Duration(try) * defaultInitialRetransmitTime if n.IPv4 != "none" && ((ifState.HaveV4 && nodeMight4(n)) || n.IsTestNode()) { @@ -975,13 +975,11 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe // need to close the underlying Pinger after a timeout // or when all ICMP probes are done, regardless of // whether the HTTPS probes have finished. - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { if err := c.measureAllICMPLatency(ctx, rs, need); err != nil { c.logf("[v1] measureAllICMPLatency: %v", err) } - }() + }) } wg.Add(len(need)) c.logf("netcheck: UDP is blocked, trying HTTPS") @@ -1072,9 +1070,7 @@ func (c *Client) runHTTPOnlyChecks(ctx context.Context, last *Report, rs *report if len(rg.Nodes) == 0 { continue } - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { node := rg.Nodes[0] req, _ := http.NewRequestWithContext(ctx, "HEAD", "https://"+node.HostName+"/derp/probe", nil) // One warm-up one to get HTTP connection set @@ -1099,7 +1095,7 @@ func (c *Client) runHTTPOnlyChecks(ctx context.Context, last *Report, rs *report } d := c.timeNow().Sub(t0) rs.addNodeLatency(node, netip.AddrPort{}, d) - }() + }) } wg.Wait() return nil diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index ab7f58febcb3b..0fd3460fa57af 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -42,8 +42,7 @@ func TestBasic(t *testing.T) { c := newTestClient(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() if err := c.Standalone(ctx, "127.0.0.1:0"); err != nil { t.Fatal(err) @@ -124,8 +123,7 @@ func TestWorksWhenUDPBlocked(t *testing.T) { c := newTestClient(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() r, err := c.GetReport(ctx, dm, nil) if err != nil { @@ -1000,7 +998,7 @@ func TestNodeAddrResolve(t *testing.T) { } t.Logf("got IPv6 addr: %v", ap) }) - t.Run("IPv6 Failure", func(t *testing.T) { + t.Run("IPv6-Failure", func(t *testing.T) { ap, ok := c.nodeAddrPort(ctx, dnV4Only, dn.STUNPort, probeIPv6) if ok { t.Fatalf("expected no addr but got: %v", ap) @@ -1038,8 +1036,7 @@ func TestNoUDPNilGetReportOpts(t *testing.T) { } c := newTestClient(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() r, err := c.GetReport(ctx, dm, nil) if err != nil { diff --git a/net/neterror/neterror_linux.go b/net/neterror/neterror_linux.go index 9add4fd1d213c..a99452de5d3b8 100644 --- a/net/neterror/neterror_linux.go +++ b/net/neterror/neterror_linux.go @@ -12,8 +12,7 @@ import ( func init() { shouldDisableUDPGSO = func(err error) bool { - var serr *os.SyscallError - if errors.As(err, &serr) { + if serr, ok := errors.AsType[*os.SyscallError](err); ok { // EIO is returned by udp_send_skb() if the device driver does not // have tx checksumming enabled, which is a hard requirement of // UDP_SEGMENT. See: diff --git a/net/netmon/loghelper_test.go b/net/netmon/loghelper_test.go index aec5206443aa4..64231e595120a 100644 --- a/net/netmon/loghelper_test.go +++ b/net/netmon/loghelper_test.go @@ -64,7 +64,7 @@ func syncTestLinkChangeLogLimiter(t *testing.T) { // InjectEvent doesn't work because it's not a major event, so we // instead inject the event ourselves. injector := eventbustest.NewInjector(t, bus) - cd, err := NewChangeDelta(nil, &State{}, true, true) + cd, err := NewChangeDelta(nil, &State{}, 0, true) if err != nil { t.Fatal(err) } diff --git a/net/netmon/netmon.go b/net/netmon/netmon.go index 1d51379d86e31..a7120cdd3375a 100644 --- a/net/netmon/netmon.go +++ b/net/netmon/netmon.go @@ -13,6 +13,7 @@ import ( "net/netip" "runtime" "slices" + "strings" "sync" "time" @@ -31,6 +32,15 @@ import ( // us check the wall time sooner than this. const pollWallTimeInterval = 15 * time.Second +// majorTimeJumpThreshold is the minimum sleep duration that warrants +// treating a time jump as a major event requiring socket rebinding, +// even if the interface state appears unchanged. After a long sleep, +// NAT mappings are likely stale and DHCP leases may have expired +// (the renewal happens after wake, so local state may not yet reflect it). +// Short sleeps (e.g., macOS DarkWake maintenance cycles of ~55s) should +// not trigger rebinding if the network state is unchanged. +const majorTimeJumpThreshold = 10 * time.Minute + // message represents a message returned from an osMon. type message interface { // Ignore is whether we should ignore this message. @@ -67,18 +77,18 @@ type Monitor struct { stop chan struct{} // closed on Stop static bool // static Monitor that doesn't actually monitor - mu syncs.Mutex // guards all following fields - cbs set.HandleSet[ChangeFunc] - ifState *State - gwValid bool // whether gw and gwSelfIP are valid - gw netip.Addr // our gateway's IP - gwSelfIP netip.Addr // our own IP address (that corresponds to gw) - started bool - closed bool - goroutines sync.WaitGroup - wallTimer *time.Timer // nil until Started; re-armed AfterFunc per tick - lastWall time.Time - timeJumped bool // whether we need to send a changed=true after a big time jump + mu syncs.Mutex // guards all following fields + cbs set.HandleSet[ChangeFunc] + ifState *State + gwValid bool // whether gw and gwSelfIP are valid + gw netip.Addr // our gateway's IP + gwSelfIP netip.Addr // our own IP address (that corresponds to gw) + started bool + closed bool + goroutines sync.WaitGroup + wallTimer *time.Timer // nil until Started; re-armed AfterFunc per tick + lastWall time.Time + jumpDuration time.Duration // wall-clock time elapsed during detected time jump; 0 if no time jump observed since reset } // ChangeFunc is a callback function registered with Monitor that's called when the @@ -97,10 +107,12 @@ type ChangeDelta struct { // It is always non-nil. new *State - // TimeJumped is whether there was a big jump in wall time since the last - // time we checked. This is a hint that a sleeping device might have - // come out of sleep. - TimeJumped bool + // JumpDuration is non-zero when a wall-clock time jump was detected, + // indicating the machine likely just woke from sleep. It is approximately + // how long the machine was asleep (the wall-clock delta since the last + // check, not an exact sleep measurement). Use TimeJumped() to check + // whether a time jump occurred. + JumpDuration time.Duration DefaultRouteInterface string @@ -121,19 +133,28 @@ type ChangeDelta struct { RebindLikelyRequired bool } +// TimeJumped reports whether a wall-clock time jump was detected, +// indicating the machine likely just woke from sleep. When true, +// JumpDuration contains the approximate duration. +func (cd *ChangeDelta) TimeJumped() bool { + return cd.JumpDuration > 0 +} + // CurrentState returns the current (new) state after the change. func (cd *ChangeDelta) CurrentState() *State { return cd.new } // NewChangeDelta builds a ChangeDelta and eagerly computes the cached fields. +// jumpDuration, if non-zero, indicates a wall-clock time jump was detected +// (the machine likely woke from sleep) and is the approximate duration of the jump. // forceViability, if true, forces DefaultInterfaceMaybeViable to be true regardless of the // actual state of the default interface. This is useful in testing. -func NewChangeDelta(old, new *State, timeJumped bool, forceViability bool) (*ChangeDelta, error) { +func NewChangeDelta(old, new *State, jumpDuration time.Duration, forceViability bool) (*ChangeDelta, error) { cd := ChangeDelta{ - old: old, - new: new, - TimeJumped: timeJumped, + old: old, + new: new, + JumpDuration: jumpDuration, } if cd.new == nil { @@ -165,10 +186,18 @@ func NewChangeDelta(old, new *State, timeJumped bool, forceViability bool) (*Cha cd.DefaultInterfaceMaybeViable = true } - // Compute rebind requirement. The default interface needs to be viable and + // Compute rebind requirement. The default interface needs to be viable and // one of the other conditions needs to be true. + // + // Short time jumps (e.g., macOS DarkWake maintenance cycles of ~55s) are + // excluded — if the network state is unchanged after a brief sleep, there's + // no reason to rebind. However, a major time jump (over majorTimeJumpThreshold) + // warrants a rebind even if the local state looks the same, because NAT + // mappings are likely stale and DHCP leases may have changed (the renewal + // happens after wake, so local state may not yet reflect it). + majorTimeJump := cd.JumpDuration >= majorTimeJumpThreshold cd.RebindLikelyRequired = (cd.old == nil || - cd.TimeJumped || + majorTimeJump || cd.DefaultInterfaceChanged || cd.InterfaceIPsChanged || cd.IsLessExpensive || @@ -181,7 +210,39 @@ func NewChangeDelta(old, new *State, timeJumped bool, forceViability bool) (*Cha // StateDesc returns a description of the old and new states for logging. func (cd *ChangeDelta) StateDesc() string { - return fmt.Sprintf("old: %v new: %v", cd.old, cd.new) + var sb strings.Builder + fmt.Fprintf(&sb, "old: %v new: %v", cd.old, cd.new) + if cd.old != nil && cd.new != nil { + if diff := cd.old.InterfaceDiff(cd.new); diff != "" { + fmt.Fprintf(&sb, " diff: %s", diff) + } + } + if cd.RebindLikelyRequired { + var reasons []string + if cd.old == nil { + reasons = append(reasons, "initial-state") + } + if cd.TimeJumped() { + reasons = append(reasons, fmt.Sprintf("time-jumped(%v)", cd.JumpDuration.Round(time.Second))) + } + if cd.DefaultInterfaceChanged { + reasons = append(reasons, "default-if-changed") + } + if cd.InterfaceIPsChanged { + reasons = append(reasons, "ips-changed") + } + if cd.IsLessExpensive { + reasons = append(reasons, "less-expensive") + } + if cd.HasPACOrProxyConfigChanged { + reasons = append(reasons, "pac-proxy-changed") + } + if cd.AvailableProtocolsChanged { + reasons = append(reasons, "protocols-changed") + } + fmt.Fprintf(&sb, " rebind-reason=[%s]", strings.Join(reasons, ",")) + } + return sb.String() } // InterfaceIPDisappeared reports whether the given IP address exists on any interface @@ -245,7 +306,12 @@ func (cd *ChangeDelta) isInterestingInterfaceChange() bool { } newIps = filterRoutableIPs(newIps) - if !oldInterface.Equal(newInterface) || !prefixesEqual(oldIps, newIps) { + // Only consider routable IP changes and up/down state transitions + // as interesting. Transient metadata changes (Flags like FlagRunning, + // MTU, etc.) should not trigger a major link change, as they cause + // false "major" events on macOS and Windows when the OS notifies us + // of interface changes that don't affect connectivity. + if oldInterface.IsUp() != newInterface.IsUp() || !prefixesEqual(oldIps, newIps) { return true } } @@ -277,8 +343,8 @@ func (cd *ChangeDelta) isInterestingInterfaceChange() bool { } oldIps = filterRoutableIPs(oldIps) - // The interface's IPs, Name, MTU, etc have changed. This is definitely interesting. - if !newInterface.Equal(oldInterface) || !prefixesEqual(oldIps, newIps) { + // Only consider routable IP changes and up/down state transitions. + if newInterface.IsUp() != oldInterface.IsUp() || !prefixesEqual(oldIps, newIps) { return true } } @@ -574,7 +640,8 @@ func (m *Monitor) handlePotentialChange(newState *State, forceCallbacks bool) { return } - delta, err := NewChangeDelta(oldState, newState, timeJumped, false) + jumpDuration := m.jumpDuration + delta, err := NewChangeDelta(oldState, newState, jumpDuration, false) if err != nil { m.logf("[unexpected] error creating ChangeDelta: %v", err) return @@ -587,12 +654,13 @@ func (m *Monitor) handlePotentialChange(newState *State, forceCallbacks bool) { // See if we have a queued or new time jump signal. if timeJumped { m.resetTimeJumpedLocked() + m.logf("time jump detected (slept %v), probably wake from sleep", jumpDuration.Round(time.Second)) } metricChange.Add(1) if delta.RebindLikelyRequired { metricChangeMajor.Add(1) } - if delta.TimeJumped { + if delta.TimeJumped() { metricChangeTimeJump.Add(1) } m.changed.Publish(*delta) @@ -654,14 +722,14 @@ func (m *Monitor) checkWallTimeAdvanceLocked() bool { panic("unreachable") // if callers are correct } now := wallTime() - if now.Sub(m.lastWall) > pollWallTimeInterval*3/2 { - m.timeJumped = true // it is reset by debounce. + if elapsed := now.Sub(m.lastWall); elapsed > pollWallTimeInterval*3/2 { + m.jumpDuration = elapsed } m.lastWall = now - return m.timeJumped + return m.jumpDuration != 0 } // resetTimeJumpedLocked consumes the signal set by checkWallTimeAdvanceLocked. func (m *Monitor) resetTimeJumpedLocked() { - m.timeJumped = false + m.jumpDuration = 0 } diff --git a/net/netmon/netmon_linux_test.go b/net/netmon/netmon_linux_test.go index c6c12e850f3fe..c4e59059ac4c4 100644 --- a/net/netmon/netmon_linux_test.go +++ b/net/netmon/netmon_linux_test.go @@ -49,7 +49,7 @@ func TestIgnoreDuplicateNEWADDR(t *testing.T) { return msg } - t.Run("suppress duplicate NEWADDRs", func(t *testing.T) { + t.Run("suppress-duplicate-NEWADDRs", func(t *testing.T) { c := nlConn{ buffered: []netlink.Message{ newAddrMsg(1, "192.168.0.5", unix.RTM_NEWADDR), @@ -69,7 +69,7 @@ func TestIgnoreDuplicateNEWADDR(t *testing.T) { } }) - t.Run("do not suppress after DELADDR", func(t *testing.T) { + t.Run("no-suppress-after-DELADDR", func(t *testing.T) { c := nlConn{ buffered: []netlink.Message{ newAddrMsg(1, "192.168.0.5", unix.RTM_NEWADDR), diff --git a/net/netmon/netmon_test.go b/net/netmon/netmon_test.go index 97c203274cd8f..a3ba4e03e17de 100644 --- a/net/netmon/netmon_test.go +++ b/net/netmon/netmon_test.go @@ -473,6 +473,98 @@ func TestRebindRequired(t *testing.T) { }, want: false, }, + { + name: "interface-flags-changed-no-ip-change", + s1: &State{ + DefaultRouteInterface: "en0", + Interface: map[string]Interface{ + "en0": {Interface: &net.Interface{ + Name: "en0", + Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning, + }}, + }, + InterfaceIPs: map[string][]netip.Prefix{ + "en0": {netip.MustParsePrefix("10.0.0.12/24")}, + }, + HaveV4: true, + }, + s2: &State{ + DefaultRouteInterface: "en0", + Interface: map[string]Interface{ + "en0": {Interface: &net.Interface{ + Name: "en0", + Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast, // FlagRunning removed + }}, + }, + InterfaceIPs: map[string][]netip.Prefix{ + "en0": {netip.MustParsePrefix("10.0.0.12/24")}, + }, + HaveV4: true, + }, + want: false, + }, + { + name: "interface-mtu-changed-no-ip-change", + s1: &State{ + DefaultRouteInterface: "en0", + Interface: map[string]Interface{ + "en0": {Interface: &net.Interface{ + Name: "en0", + Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning, + MTU: 1500, + }}, + }, + InterfaceIPs: map[string][]netip.Prefix{ + "en0": {netip.MustParsePrefix("10.0.0.12/24")}, + }, + HaveV4: true, + }, + s2: &State{ + DefaultRouteInterface: "en0", + Interface: map[string]Interface{ + "en0": {Interface: &net.Interface{ + Name: "en0", + Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning, + MTU: 9000, + }}, + }, + InterfaceIPs: map[string][]netip.Prefix{ + "en0": {netip.MustParsePrefix("10.0.0.12/24")}, + }, + HaveV4: true, + }, + want: false, + }, + { + name: "interface-went-down", + s1: &State{ + DefaultRouteInterface: "en0", + Interface: map[string]Interface{ + "en0": {Interface: &net.Interface{ + Name: "en0", + Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning, + }}, + }, + InterfaceIPs: map[string][]netip.Prefix{ + "en0": {netip.MustParsePrefix("10.0.0.12/24")}, + }, + HaveV4: true, + }, + s2: &State{ + DefaultRouteInterface: "en0", + Interface: map[string]Interface{ + "en0": {Interface: &net.Interface{ + Name: "en0", + Flags: net.FlagBroadcast | net.FlagMulticast, // FlagUp removed + }}, + }, + InterfaceIPs: map[string][]netip.Prefix{ + "en0": {netip.MustParsePrefix("10.0.0.12/24")}, + }, + HaveV4: true, + }, + want: true, + }, } withIsInterestingInterface(t, func(ni Interface, pfxs []netip.Prefix) bool { @@ -498,7 +590,7 @@ func TestRebindRequired(t *testing.T) { } SetTailscaleInterfaceProps(tt.tsIfName, 1) - cd, err := NewChangeDelta(tt.s1, tt.s2, false, true) + cd, err := NewChangeDelta(tt.s1, tt.s2, 0, true) if err != nil { t.Fatalf("NewChangeDelta error: %v", err) } @@ -510,6 +602,71 @@ func TestRebindRequired(t *testing.T) { } } +func TestTimeJumpedDoesNotTriggerRebind(t *testing.T) { + s := &State{ + DefaultRouteInterface: "en0", + Interface: map[string]Interface{ + "en0": {Interface: &net.Interface{ + Name: "en0", + Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning, + }}, + }, + InterfaceIPs: map[string][]netip.Prefix{ + "en0": {netip.MustParsePrefix("10.0.0.12/24")}, + }, + HaveV4: true, + } + + // A short time jump (e.g., macOS DarkWake maintenance cycle ~55s) + // with unchanged network state should NOT trigger rebind. + cd, err := NewChangeDelta(s, s, 55*time.Second, true) + if err != nil { + t.Fatalf("NewChangeDelta error: %v", err) + } + if cd.RebindLikelyRequired { + t.Error("RebindLikelyRequired = true for short time jump with unchanged state; want false") + } + if !cd.TimeJumped() { + t.Error("TimeJumped = false; want true") + } + + // A major time jump (>10m) with unchanged state SHOULD trigger rebind, + // because NAT mappings are likely stale. + cd2, err := NewChangeDelta(s, s, 2*time.Hour, true) + if err != nil { + t.Fatalf("NewChangeDelta error: %v", err) + } + if !cd2.RebindLikelyRequired { + t.Error("RebindLikelyRequired = false for major time jump (2h); want true") + } + + // A short time jump with changed state SHOULD trigger rebind. + s2 := &State{ + DefaultRouteInterface: "en0", + Interface: map[string]Interface{ + "en0": {Interface: &net.Interface{ + Name: "en0", + Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning, + }}, + }, + InterfaceIPs: map[string][]netip.Prefix{ + "en0": {netip.MustParsePrefix("10.0.0.99/24")}, // IP changed + }, + HaveV4: true, + } + + saveAndRestoreTailscaleIfaceProps(t) + SetTailscaleInterfaceProps("", 0) + + cd3, err := NewChangeDelta(s, s2, 55*time.Second, true) + if err != nil { + t.Fatalf("NewChangeDelta error: %v", err) + } + if !cd3.RebindLikelyRequired { + t.Error("RebindLikelyRequired = false for time jump with changed IP; want true") + } +} + func saveAndRestoreTailscaleIfaceProps(t *testing.T) { t.Helper() index, _ := TailscaleInterfaceIndex() @@ -612,6 +769,71 @@ func TestPrefixesEqual(t *testing.T) { } } +func TestInterfaceDiff(t *testing.T) { + tests := []struct { + name string + s1, s2 *State + wantDiff string // substring expected in diff output; "" means no diff + }{ + { + name: "equal", + s1: &State{HaveV4: true, DefaultRouteInterface: "en0"}, + s2: &State{HaveV4: true, DefaultRouteInterface: "en0"}, + wantDiff: "", + }, + { + name: "flags-changed", + s1: &State{ + DefaultRouteInterface: "en0", + Interface: map[string]Interface{ + "en0": {Interface: &net.Interface{ + Name: "en0", + Flags: net.FlagUp | net.FlagRunning, + }}, + }, + }, + s2: &State{ + DefaultRouteInterface: "en0", + Interface: map[string]Interface{ + "en0": {Interface: &net.Interface{ + Name: "en0", + Flags: net.FlagUp, + }}, + }, + }, + wantDiff: "flags", + }, + { + name: "mtu-changed", + s1: &State{ + Interface: map[string]Interface{ + "en0": {Interface: &net.Interface{Name: "en0", MTU: 1500}}, + }, + }, + s2: &State{ + Interface: map[string]Interface{ + "en0": {Interface: &net.Interface{Name: "en0", MTU: 9000}}, + }, + }, + wantDiff: "MTU", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.s1.InterfaceDiff(tt.s2) + if tt.wantDiff == "" { + if got != "" { + t.Errorf("InterfaceDiff = %q; want empty", got) + } + } else { + if !strings.Contains(got, tt.wantDiff) { + t.Errorf("InterfaceDiff = %q; want substring %q", got, tt.wantDiff) + } + } + }) + } +} + func TestForeachInterface(t *testing.T) { tests := []struct { name string diff --git a/net/netmon/state.go b/net/netmon/state.go index cdfa1d0fbe552..94554dcc39c1e 100644 --- a/net/netmon/state.go +++ b/net/netmon/state.go @@ -443,6 +443,91 @@ func (a Interface) Equal(b Interface) bool { return true } +// InterfaceDiff returns a human-readable summary of the differences between s +// and s2 that would cause Equal to return false. It returns "" if the states +// are equal. This is useful for debugging false link change events where the +// State.String() output looks identical but Equal() returns false because it +// checks fields not shown in String() (like interface Flags, MTU, HardwareAddr). +func (s *State) InterfaceDiff(s2 *State) string { + if s == nil && s2 == nil { + return "" + } + if s == nil { + return "old=nil" + } + if s2 == nil { + return "new=nil" + } + var diffs []string + if s.HaveV6 != s2.HaveV6 { + diffs = append(diffs, fmt.Sprintf("HaveV6: %v->%v", s.HaveV6, s2.HaveV6)) + } + if s.HaveV4 != s2.HaveV4 { + diffs = append(diffs, fmt.Sprintf("HaveV4: %v->%v", s.HaveV4, s2.HaveV4)) + } + if s.IsExpensive != s2.IsExpensive { + diffs = append(diffs, fmt.Sprintf("IsExpensive: %v->%v", s.IsExpensive, s2.IsExpensive)) + } + if s.DefaultRouteInterface != s2.DefaultRouteInterface { + diffs = append(diffs, fmt.Sprintf("DefaultRoute: %q->%q", s.DefaultRouteInterface, s2.DefaultRouteInterface)) + } + if s.HTTPProxy != s2.HTTPProxy { + diffs = append(diffs, fmt.Sprintf("HTTPProxy: %q->%q", s.HTTPProxy, s2.HTTPProxy)) + } + if s.PAC != s2.PAC { + diffs = append(diffs, fmt.Sprintf("PAC: %q->%q", s.PAC, s2.PAC)) + } + if len(s.Interface) != len(s2.Interface) { + diffs = append(diffs, fmt.Sprintf("numInterfaces: %d->%d", len(s.Interface), len(s2.Interface))) + } + if len(s.InterfaceIPs) != len(s2.InterfaceIPs) { + diffs = append(diffs, fmt.Sprintf("numInterfaceIPs: %d->%d", len(s.InterfaceIPs), len(s2.InterfaceIPs))) + } + for iname, i := range s.Interface { + i2, ok := s2.Interface[iname] + if !ok { + diffs = append(diffs, fmt.Sprintf("if %s: removed", iname)) + continue + } + if !i.Equal(i2) { + if i.Interface != nil && i2.Interface != nil { + if i.Flags != i2.Flags { + diffs = append(diffs, fmt.Sprintf("if %s flags: %v->%v", iname, i.Flags, i2.Flags)) + } + if i.MTU != i2.MTU { + diffs = append(diffs, fmt.Sprintf("if %s MTU: %d->%d", iname, i.MTU, i2.MTU)) + } + if i.Index != i2.Index { + diffs = append(diffs, fmt.Sprintf("if %s index: %d->%d", iname, i.Index, i2.Index)) + } + if !bytes.Equal([]byte(i.HardwareAddr), []byte(i2.HardwareAddr)) { + diffs = append(diffs, fmt.Sprintf("if %s hwaddr: %v->%v", iname, i.HardwareAddr, i2.HardwareAddr)) + } + } + if i.Desc != i2.Desc { + diffs = append(diffs, fmt.Sprintf("if %s desc: %q->%q", iname, i.Desc, i2.Desc)) + } + } + } + for iname := range s2.Interface { + if _, ok := s.Interface[iname]; !ok { + diffs = append(diffs, fmt.Sprintf("if %s: added", iname)) + } + } + for iname, vv := range s.InterfaceIPs { + vv2 := s2.InterfaceIPs[iname] + if !slices.Equal(vv, vv2) { + diffs = append(diffs, fmt.Sprintf("ips %s: %v->%v", iname, vv, vv2)) + } + } + for iname := range s2.InterfaceIPs { + if _, ok := s.InterfaceIPs[iname]; !ok { + diffs = append(diffs, fmt.Sprintf("ips %s: added %v", iname, s2.InterfaceIPs[iname])) + } + } + return strings.Join(diffs, "; ") +} + func (s *State) HasPAC() bool { return s != nil && s.PAC != "" } // AnyInterfaceUp reports whether any interface seems like it has Internet access. @@ -812,11 +897,8 @@ func (m *Monitor) HasCGNATInterface() (bool, error) { if hasCGNATInterface || !i.IsUp() || isTailscaleInterface(i.Name, pfxs) { return } - for _, pfx := range pfxs { - if cgnatRange.Overlaps(pfx) { - hasCGNATInterface = true - break - } + if slices.ContainsFunc(pfxs, cgnatRange.Overlaps) { + hasCGNATInterface = true } }) if err != nil { diff --git a/net/netns/netns.go b/net/netns/netns.go index 5d692c787eae8..fe7ff4dcbadd8 100644 --- a/net/netns/netns.go +++ b/net/netns/netns.go @@ -46,6 +46,18 @@ func SetBindToInterfaceByRoute(logf logger.Logf, v bool) { } } +// When true, disableAndroidBindToActiveNetwork skips binding sockets to the currently +// active network on Android. +var disableAndroidBindToActiveNetwork atomic.Bool + +// SetDisableAndroidBindToActiveNetwork disables the default behavior of binding +// sockets to the currently active network on Android. +func SetDisableAndroidBindToActiveNetwork(logf logger.Logf, v bool) { + if runtime.GOOS == "android" && disableAndroidBindToActiveNetwork.Swap(v) != v { + logf("netns: disableAndroidBindToActiveNetwork changed to %v", v) + } +} + var disableBindConnToInterface atomic.Bool // SetDisableBindConnToInterface disables the (normal) behavior of binding diff --git a/net/netns/netns_android.go b/net/netns/netns_android.go index e747f61f40e50..7c5fe3214dcbf 100644 --- a/net/netns/netns_android.go +++ b/net/netns/netns_android.go @@ -17,6 +17,9 @@ import ( var ( androidProtectFuncMu sync.Mutex androidProtectFunc func(fd int) error + + androidBindToNetworkFuncMu sync.Mutex + androidBindToNetworkFunc func(fd int) error ) // UseSocketMark reports whether SO_MARK is in use. Android does not use SO_MARK. @@ -50,6 +53,14 @@ func SetAndroidProtectFunc(f func(fd int) error) { androidProtectFunc = f } +// SetAndroidBindToNetworkFunc registers a func provided by Android that binds +// the socket FD to the currently selected underlying network. +func SetAndroidBindToNetworkFunc(f func(fd int) error) { + androidBindToNetworkFuncMu.Lock() + defer androidBindToNetworkFuncMu.Unlock() + androidBindToNetworkFunc = f +} + func control(logger.Logf, *netmon.Monitor) func(network, address string, c syscall.RawConn) error { return controlC } @@ -60,14 +71,36 @@ func control(logger.Logf, *netmon.Monitor) func(network, address string, c sysca // and net.ListenConfig.Control. func controlC(network, address string, c syscall.RawConn) error { var sockErr error + err := c.Control(func(fd uintptr) { + fdInt := int(fd) + + // Protect from VPN loops androidProtectFuncMu.Lock() - f := androidProtectFunc + pf := androidProtectFunc androidProtectFuncMu.Unlock() - if f != nil { - sockErr = f(int(fd)) + if pf != nil { + if err := pf(fdInt); err != nil { + sockErr = err + return + } + } + + if disableAndroidBindToActiveNetwork.Load() { + return + } + + androidBindToNetworkFuncMu.Lock() + bf := androidBindToNetworkFunc + androidBindToNetworkFuncMu.Unlock() + if bf != nil { + if err := bf(fdInt); err != nil { + sockErr = err + return + } } }) + if err != nil { return fmt.Errorf("RawConn.Control on %T: %w", c, err) } diff --git a/net/netutil/routes.go b/net/netutil/routes.go index c8212b9af66dd..26f2de97c5767 100644 --- a/net/netutil/routes.go +++ b/net/netutil/routes.go @@ -41,8 +41,8 @@ func CalcAdvertiseRoutes(advertiseRoutes string, advertiseDefaultRoute bool) ([] routeMap := map[netip.Prefix]bool{} if advertiseRoutes != "" { var default4, default6 bool - advroutes := strings.Split(advertiseRoutes, ",") - for _, s := range advroutes { + advroutes := strings.SplitSeq(advertiseRoutes, ",") + for s := range advroutes { ipp, err := netip.ParsePrefix(s) if err != nil { return nil, fmt.Errorf("%q is not a valid IP address or CIDR prefix", s) diff --git a/net/packet/geneve_test.go b/net/packet/geneve_test.go index bd673cd0d963a..43a64efde0e80 100644 --- a/net/packet/geneve_test.go +++ b/net/packet/geneve_test.go @@ -9,7 +9,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "tailscale.com/types/ptr" ) func TestGeneveHeader(t *testing.T) { @@ -47,22 +46,22 @@ func TestVirtualNetworkID(t *testing.T) { }, { "Set 0", - ptr.To(uint32(0)), + new(uint32(0)), 0, }, { "Set 1", - ptr.To(uint32(1)), + new(uint32(1)), 1, }, { "Set math.MaxUint32", - ptr.To(uint32(math.MaxUint32)), + new(uint32(math.MaxUint32)), 1<<24 - 1, }, { "Set max 3-byte value", - ptr.To(uint32(1<<24 - 1)), + new(uint32(1<<24 - 1)), 1<<24 - 1, }, } diff --git a/net/packet/icmp6_test.go b/net/packet/icmp6_test.go index 0348824b62296..51de86a4a2d01 100644 --- a/net/packet/icmp6_test.go +++ b/net/packet/icmp6_test.go @@ -57,7 +57,7 @@ func TestICMPv6Checksum(t *testing.T) { "\x00\x00\x00\x00\x00\x00\x20\x0e\x80\x00\x4a\x9a\x2e\xea\x00\x02" + "\x61\xb1\x9e\xad\x00\x06\x45\xaa" // The packet that we'd originally generated incorrectly, but with the checksum - // bytes fixed per WireShark's correct calculation: + // bytes fixed per Wireshark's correct calculation: const wantRes = "\x60\x00\xf8\xff\x00\x10\x3a\x40\x26\x07\xf8\xb0\x40\x0a\x08\x07" + "\x00\x00\x00\x00\x00\x00\x20\x0e\xfd\x7a\x11\x5c\xa1\xe0\xab\x12" + "\x48\x43\xcd\x96\x62\x7b\x65\x28\x81\x00\x49\x9a\x2e\xea\x00\x02" + diff --git a/net/ping/ping.go b/net/ping/ping.go index de79da51c5c48..42d381c7311cb 100644 --- a/net/ping/ping.go +++ b/net/ping/ping.go @@ -29,8 +29,10 @@ import ( ) const ( - v4Type = "ip4:icmp" - v6Type = "ip6:icmp" + v4Type = "ip4:icmp" + v6Type = "ip6:icmp" + v4UDPType = "udp4" // unprivileged datagram-oriented ICMPv4 + v6UDPType = "udp6" // unprivileged datagram-oriented ICMPv6 ) type response struct { @@ -54,12 +56,30 @@ type ListenPacketer interface { // A new instance should be created for each concurrent set of ping requests; // this type should not be reused. type Pinger struct { + // options that must be set before the first call to Send + + // Unprivileged, when set, makes the Pinger use non-privileged + // datagram-oriented ICMP sockets ("udp4"/"udp6") opened via + // golang.org/x/net/icmp.ListenPacket instead of raw ICMP sockets + // ("ip4:icmp"/"ip6:icmp") opened via the configured ListenPacketer. + // + // Unprivileged mode is supported on macOS, iOS, and Linux (subject to + // the /proc/sys/net/ipv4/ping_group_range sysctl). When set, the + // ListenPacketer passed to New is ignored and the kernel rewrites the + // outgoing ICMP echo ID to match the socket; replies are matched by + // sequence number and echo data only. + // + // Must be set before the first call to Send. + Unprivileged bool + + Verbose bool // verbose logging + Logf logger.Logf // optional logging function; if nil, logs to the standard logger + lp ListenPacketer // closed guards against send incrementing the waitgroup concurrently with close. - closed atomic.Bool - Logf logger.Logf - Verbose bool + closed atomic.Bool + timeNow func() time.Time id uint16 // uint16 per RFC 792 wg sync.WaitGroup @@ -95,7 +115,17 @@ func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, return nil, net.ErrClosed } - c, err := p.lp.ListenPacket(ctx, typ, addr) + var c net.PacketConn + var err error + if p.Unprivileged { + // icmp.ListenPacket on "udp4"/"udp6" opens a datagram-oriented + // ICMP socket that does not require elevated privileges. The + // returned *icmp.PacketConn implements net.PacketConn and, on + // Darwin/iOS, strips the IPv4 header on read via IP_STRIPHDR. + c, err = icmp.ListenPacket(typ, addr) + } else { + c, err = p.lp.ListenPacket(ctx, typ, addr) + } if err != nil { return nil, err } @@ -125,7 +155,7 @@ func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error } var addr = "0.0.0.0" - if typ == v6Type { + if typ == v6Type || typ == v6UDPType { addr = "::" } c, err := p.mkconn(ctx, typ, addr) @@ -216,9 +246,9 @@ func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) { // and IPv6. var icmpType icmp.Type switch typ { - case v4Type: + case v4Type, v4UDPType: icmpType = ipv4.ICMPTypeEchoReply - case v6Type: + case v6Type, v6UDPType: icmpType = ipv6.ICMPTypeEchoReply default: p.vlogf("handleResponse: unknown icmp.Type") @@ -243,7 +273,10 @@ func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) { } // We assume we sent this if the ID in the response is ours. - if uint16(resp.ID) != p.id { + // In unprivileged ICMP DGRAM mode the kernel rewrites the ID to match + // the socket, so the value we set on the way out is not what comes + // back; rely on sequence and data matching instead. + if !p.Unprivileged && uint16(resp.ID) != p.id { p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID) return } @@ -294,14 +327,30 @@ func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Dur } if ap.Is6() { icmpType = ipv6.ICMPTypeEchoRequest - conn, err = p.getConn(ctx, v6Type) + typ := v6Type + if p.Unprivileged { + typ = v6UDPType + } + conn, err = p.getConn(ctx, typ) } else { - conn, err = p.getConn(ctx, v4Type) + typ := v4Type + if p.Unprivileged { + typ = v4UDPType + } + conn, err = p.getConn(ctx, typ) } if err != nil { return 0, err } + // In unprivileged ICMP DGRAM mode (icmp.ListenPacket on "udp4"/"udp6"), + // the kernel requires a *net.UDPAddr destination for WriteTo even though + // the wire packet is ICMP. + writeDst := dest + if p.Unprivileged { + writeDst = &net.UDPAddr{IP: ap.AsSlice(), Zone: ap.Zone()} + } + m := icmp.Message{ Type: icmpType, Code: 0, @@ -324,7 +373,7 @@ func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Dur p.mu.Unlock() start := p.timeNow() - n, err := conn.WriteTo(b, dest) + n, err := conn.WriteTo(b, writeDst) if err != nil { return 0, err } else if n != len(b) { diff --git a/net/socks5/socks5.go b/net/socks5/socks5.go index 729fc8e882cf1..f67dc1ecc202a 100644 --- a/net/socks5/socks5.go +++ b/net/socks5/socks5.go @@ -21,6 +21,7 @@ import ( "io" "log" "net" + "slices" "strconv" "time" @@ -488,10 +489,8 @@ func parseClientGreeting(r io.Reader, authMethod byte) error { if err != nil { return fmt.Errorf("could not read methods") } - for _, m := range methods { - if m == authMethod { - return nil - } + if slices.Contains(methods, authMethod) { + return nil } return fmt.Errorf("no acceptable auth methods") } diff --git a/net/socks5/socks5_test.go b/net/socks5/socks5_test.go index 9fbc11f8c0dfb..84ef4be7bc651 100644 --- a/net/socks5/socks5_test.go +++ b/net/socks5/socks5_test.go @@ -180,11 +180,11 @@ func TestUDP(t *testing.T) { const echoServerNumber = 3 echoServerListener := make([]net.PacketConn, echoServerNumber) - for i := 0; i < echoServerNumber; i++ { + for i := range echoServerNumber { echoServerListener[i] = newUDPEchoServer() } defer func() { - for i := 0; i < echoServerNumber; i++ { + for i := range echoServerNumber { _ = echoServerListener[i].Close() } }() @@ -222,7 +222,7 @@ func TestUDP(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client reqeust + _, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client request if err != nil { t.Fatal(err) } @@ -277,10 +277,10 @@ func TestUDP(t *testing.T) { } defer socks5UDPConn.Close() - for i := 0; i < echoServerNumber; i++ { + for i := range echoServerNumber { port := echoServerListener[i].LocalAddr().(*net.UDPAddr).Port addr := socksAddr{addrType: ipv4, addr: "127.0.0.1", port: uint16(port)} - requestBody := []byte(fmt.Sprintf("Test %d", i)) + requestBody := fmt.Appendf(nil, "Test %d", i) responseBody := sendUDPAndWaitResponse(socks5UDPConn, addr, requestBody) if !bytes.Equal(requestBody, responseBody) { t.Fatalf("got: %q want: %q", responseBody, requestBody) diff --git a/net/speedtest/speedtest_test.go b/net/speedtest/speedtest_test.go index 1fbd0915b219f..eb851eb26e332 100644 --- a/net/speedtest/speedtest_test.go +++ b/net/speedtest/speedtest_test.go @@ -47,7 +47,7 @@ func TestDownload(t *testing.T) { // ensure that the test returns an appropriate number of Result structs expectedLen := int(DefaultDuration.Seconds()) + 1 - t.Run("download test", func(t *testing.T) { + t.Run("download-test", func(t *testing.T) { // conduct a download test results, err := RunClient(Download, DefaultDuration, serverIP) @@ -65,7 +65,7 @@ func TestDownload(t *testing.T) { } }) - t.Run("upload test", func(t *testing.T) { + t.Run("upload-test", func(t *testing.T) { // conduct an upload test results, err := RunClient(Upload, DefaultDuration, serverIP) diff --git a/net/stun/stun_test.go b/net/stun/stun_test.go index 7f754324e7597..c26a6a5c7320e 100644 --- a/net/stun/stun_test.go +++ b/net/stun/stun_test.go @@ -60,7 +60,7 @@ var responseTests = []struct { wantPort: 59029, }, { - name: "stun.sipgate.net:10000", + name: "stun-sipgate-net-10000", data: []byte{ 0x01, 0x01, 0x00, 0x44, 0x21, 0x12, 0xa4, 0x42, 0x48, 0x2e, 0xb6, 0x47, 0x15, 0xe8, 0xb2, 0x8e, @@ -82,7 +82,7 @@ var responseTests = []struct { wantPort: 58539, }, { - name: "stun.powervoip.com:3478", + name: "stun-powervoip-com-3478", data: []byte{ 0x01, 0x01, 0x00, 0x24, 0x21, 0x12, 0xa4, 0x42, 0x7e, 0x57, 0x96, 0x68, 0x29, 0xf4, 0x44, 0x60, @@ -100,7 +100,7 @@ var responseTests = []struct { wantPort: 59859, }, { - name: "in-process pion server", + name: "in-process-pion-server", data: []byte{ 0x01, 0x01, 0x00, 0x24, 0x21, 0x12, 0xa4, 0x42, 0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c, @@ -119,7 +119,7 @@ var responseTests = []struct { wantPort: 61300, }, { - name: "stuntman-server ipv6", + name: "stuntman-server-ipv6", data: []byte{ 0x01, 0x01, 0x00, 0x48, 0x21, 0x12, 0xa4, 0x42, 0x06, 0xf5, 0x66, 0x85, 0xd2, 0x8a, 0xf3, 0xe6, diff --git a/net/stunserver/stunserver_test.go b/net/stunserver/stunserver_test.go index c96aea4d15973..f9efe21f30494 100644 --- a/net/stunserver/stunserver_test.go +++ b/net/stunserver/stunserver_test.go @@ -60,8 +60,7 @@ func TestSTUNServer(t *testing.T) { func BenchmarkServerSTUN(b *testing.B) { b.ReportAllocs() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := b.Context() s := New(ctx) s.Listen("localhost:0") diff --git a/net/tlsdial/tlsdial.go b/net/tlsdial/tlsdial.go index ffc8c90a80f96..417c925b77161 100644 --- a/net/tlsdial/tlsdial.go +++ b/net/tlsdial/tlsdial.go @@ -59,15 +59,26 @@ var mitmBlockWarnable = health.Register(&health.Warnable{ // the baked-in LetsEncrypt roots as a fallback validation method. // // If base is non-nil, it's cloned as the base config before -// being configured and returned. +// being configured and returned. If base.RootCAs is non-nil, it is +// used as an additional set of trusted roots (after system roots, +// before baked-in LetsEncrypt roots). This is used on Android to +// trust user-installed CA certificates that Go's crypto/x509 +// does not see. +// // If ht is non-nil, it's used to report health errors. func Config(ht *health.Tracker, base *tls.Config) *tls.Config { + var extraRoots *x509.CertPool + if base != nil { + extraRoots = base.RootCAs + } + var conf *tls.Config if base == nil { conf = new(tls.Config) } else { conf = base.Clone() } + conf.RootCAs = nil // we do our own verification in VerifyConnection // Note: we do NOT set conf.ServerName here (as we accidentally did // previously), as this path is also used when dialing an HTTPS proxy server @@ -77,7 +88,7 @@ func Config(ht *health.Tracker, base *tls.Config) *tls.Config { if buildfeatures.HasDebug { // If SSLKEYLOGFILE is set, it's a file to which we write our TLS private keys - // in a way that WireShark can read. + // in a way that Wireshark can read. // // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format if n := os.Getenv("SSLKEYLOGFILE"); n != "" { @@ -165,7 +176,26 @@ func Config(ht *health.Tracker, base *tls.Config) *tls.Config { if debug() { log.Printf("tlsdial(sys %q): %v", dialedHost, errSys) } - if !buildfeatures.HasBakedRoots || (errSys == nil && !debug()) { + if errSys == nil && !debug() { + return nil + } + + // If extra roots were provided (e.g. user-installed CAs on + // Android), try those next. + if extraRoots != nil { + opts.Roots = extraRoots + _, errExtra := cs.PeerCertificates[0].Verify(opts) + if debug() { + log.Printf("tlsdial(extra %q): %v", dialedHost, errExtra) + } + if errExtra == nil { + atomic.AddInt32(&counterFallbackOK, 1) + return nil + } + opts.Roots = nil // reset for baked roots check + } + + if !buildfeatures.HasBakedRoots { return errSys } @@ -178,7 +208,11 @@ func Config(ht *health.Tracker, base *tls.Config) *tls.Config { } else if bakedErr != nil { if _, loaded := tlsdialWarningPrinted.LoadOrStore(dialedHost, true); !loaded { if errSys != nil { - log.Printf("tlsdial: error: server cert for %q failed both system roots & Let's Encrypt root validation", dialedHost) + if extraRoots != nil { + log.Printf("tlsdial: error: server cert for %q failed system roots, extra roots & Let's Encrypt root validation", dialedHost) + } else { + log.Printf("tlsdial: error: server cert for %q failed both system roots & Let's Encrypt root validation", dialedHost) + } } } } @@ -213,6 +247,10 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) { c.ServerName = certDNSName return } + + extraRoots := c.RootCAs + c.RootCAs = nil + // Set InsecureSkipVerify to prevent crypto/tls from doing its // own cert verification, but do the same work that it'd do // (but using certDNSName) in the VerifyPeerCertificate hook. @@ -242,7 +280,21 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) { if debug() { log.Printf("tlsdial(sys %q/%q): %v", c.ServerName, certDNSName, errSys) } - if !buildfeatures.HasBakedRoots || errSys == nil { + if errSys == nil { + return nil + } + if extraRoots != nil { + opts.Roots = extraRoots + _, errExtra := certs[0].Verify(opts) + if debug() { + log.Printf("tlsdial(extra %q/%q): %v", c.ServerName, certDNSName, errExtra) + } + if errExtra == nil { + return nil + } + opts.Roots = nil + } + if !buildfeatures.HasBakedRoots { return errSys } opts.Roots = bakedroots.Get() diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index ebbafa52b01e9..ca08810a3da0e 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -515,6 +515,33 @@ func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, return stdDialer.DialContext(ctx, network, ipp.String()) } +// UserDialPlan resolves addr and reports whether the dialer would +// handle it via Tailscale. If viaTailscale is false, the resolved +// address is not a Tailscale route and the caller may dial it directly. +// +// Warning: there is a TOCTOU race if addr contains a DNS name and the +// caller subsequently passes the same DNS name to [Dialer.UserDial], as DNS +// may resolve differently the second time. Callers who want to only +// dial over Tailscale should call [Dialer.UserDial] with the returned +// ipp.String() (an IP:port) rather than the original DNS name. +func (d *Dialer) UserDialPlan(ctx context.Context, network, addr string) (ipp netip.AddrPort, viaTailscale bool, err error) { + ipp, err = d.userDialResolve(ctx, network, addr) + if err != nil { + return netip.AddrPort{}, false, err + } + if d.UseNetstackForIP != nil && d.UseNetstackForIP(ipp.Addr()) { + return ipp, true, nil + } + if routes := d.routes.Load(); routes != nil { + isTailscaleRoute, _ := routes.Lookup(ipp.Addr()) + return ipp, isTailscaleRoute, nil + } + if version.IsMacGUIVariant() && tsaddr.IsTailscaleIP(ipp.Addr()) { + return ipp, true, nil + } + return ipp, false, nil +} + // dialPeerAPI connects to a Tailscale peer's peerapi over TCP. // // network must a "tcp" type, and addr must be an ip:port. Name resolution diff --git a/net/tsdial/tsdial_test.go b/net/tsdial/tsdial_test.go new file mode 100644 index 0000000000000..92960acbe38b1 --- /dev/null +++ b/net/tsdial/tsdial_test.go @@ -0,0 +1,97 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsdial + +import ( + "context" + "net/netip" + "testing" + + "github.com/gaissmai/bart" +) + +func TestUserDialPlan(t *testing.T) { + tests := []struct { + name string + addr string + routes map[netip.Prefix]bool // nil means no routes configured + useNetstackFor func(netip.Addr) bool // nil means not set + wantVia bool + wantAddr netip.AddrPort + }{ + { + name: "loopback_no_routes", + addr: "127.0.0.1:8080", + wantVia: false, + wantAddr: netip.MustParseAddrPort("127.0.0.1:8080"), + }, + { + name: "loopback_v6_no_routes", + addr: "[::1]:8080", + wantVia: false, + wantAddr: netip.MustParseAddrPort("[::1]:8080"), + }, + { + name: "tailscale_ip_in_routes", + addr: "100.64.1.1:22", + routes: map[netip.Prefix]bool{ + netip.MustParsePrefix("100.64.0.0/10"): true, + }, + wantVia: true, + wantAddr: netip.MustParseAddrPort("100.64.1.1:22"), + }, + { + name: "non_tailscale_ip_in_local_routes", + addr: "10.0.0.5:80", + routes: map[netip.Prefix]bool{ + netip.MustParsePrefix("100.64.0.0/10"): true, + netip.MustParsePrefix("10.0.0.0/8"): false, // local route + }, + wantVia: false, + wantAddr: netip.MustParseAddrPort("10.0.0.5:80"), + }, + { + name: "loopback_with_routes_configured", + addr: "127.0.0.1:3000", + routes: map[netip.Prefix]bool{ + netip.MustParsePrefix("100.64.0.0/10"): true, + }, + wantVia: false, + wantAddr: netip.MustParseAddrPort("127.0.0.1:3000"), + }, + { + name: "netstack_for_ip", + addr: "100.100.100.100:53", + useNetstackFor: func(ip netip.Addr) bool { + return ip == netip.MustParseAddr("100.100.100.100") + }, + wantVia: true, + wantAddr: netip.MustParseAddrPort("100.100.100.100:53"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &Dialer{} + if tt.routes != nil { + rt := &bart.Table[bool]{} + for pfx, v := range tt.routes { + rt.Insert(pfx, v) + } + d.routes.Store(rt) + } + d.UseNetstackForIP = tt.useNetstackFor + + ipp, viaTailscale, err := d.UserDialPlan(context.Background(), "tcp", tt.addr) + if err != nil { + t.Fatalf("UserDialPlan: %v", err) + } + if viaTailscale != tt.wantVia { + t.Errorf("viaTailscale = %v, want %v", viaTailscale, tt.wantVia) + } + if ipp != tt.wantAddr { + t.Errorf("addr = %v, want %v", ipp, tt.wantAddr) + } + }) + } +} diff --git a/net/tshttpproxy/tshttpproxy_synology_test.go b/net/tshttpproxy/tshttpproxy_synology_test.go index a57ac1558d4f4..7360bb6f8fe29 100644 --- a/net/tshttpproxy/tshttpproxy_synology_test.go +++ b/net/tshttpproxy/tshttpproxy_synology_test.go @@ -28,7 +28,7 @@ func TestSynologyProxyFromConfigCached(t *testing.T) { tstest.Replace(t, &synologyProxyConfigPath, filepath.Join(t.TempDir(), "proxy.conf")) - t.Run("no config file", func(t *testing.T) { + t.Run("no-config-file", func(t *testing.T) { if _, err := os.Stat(synologyProxyConfigPath); err == nil { t.Fatalf("%s must not exist for this test", synologyProxyConfigPath) } @@ -52,7 +52,7 @@ func TestSynologyProxyFromConfigCached(t *testing.T) { } }) - t.Run("config file updated", func(t *testing.T) { + t.Run("config-file-updated", func(t *testing.T) { cache.updated = time.Now() cache.httpProxy = nil cache.httpsProxy = nil @@ -84,7 +84,7 @@ https_port=443 } }) - t.Run("config file removed", func(t *testing.T) { + t.Run("config-file-removed", func(t *testing.T) { cache.updated = time.Now() cache.httpProxy = urlMustParse("http://127.0.0.1/") cache.httpsProxy = urlMustParse("http://127.0.0.1/") @@ -108,7 +108,7 @@ https_port=443 } }) - t.Run("picks proxy from request scheme", func(t *testing.T) { + t.Run("picks-proxy-from-request-scheme", func(t *testing.T) { cache.updated = time.Now() cache.httpProxy = nil cache.httpsProxy = nil @@ -164,7 +164,7 @@ func TestSynologyProxiesFromConfig(t *testing.T) { return openReader, openErr }) - t.Run("with config", func(t *testing.T) { + t.Run("with-config", func(t *testing.T) { mc := &mustCloser{Reader: strings.NewReader(` proxy_user=foo proxy_pwd=bar @@ -200,7 +200,7 @@ http_port=80 }) - t.Run("nonexistent config", func(t *testing.T) { + t.Run("nonexistent-config", func(t *testing.T) { openReader = nil openErr = os.ErrNotExist @@ -216,7 +216,7 @@ http_port=80 } }) - t.Run("error opening config", func(t *testing.T) { + t.Run("error-opening-config", func(t *testing.T) { openReader = nil openErr = errors.New("example error") diff --git a/net/tshttpproxy/tshttpproxy_test.go b/net/tshttpproxy/tshttpproxy_test.go index da847429d4bd4..b391c74d89df0 100644 --- a/net/tshttpproxy/tshttpproxy_test.go +++ b/net/tshttpproxy/tshttpproxy_test.go @@ -97,7 +97,7 @@ func TestSetSelfProxy(t *testing.T) { wantHTTPS string }{ { - name: "no self proxy", + name: "no-self-proxy", env: map[string]string{ "HTTP_PROXY": "127.0.0.1:1234", "HTTPS_PROXY": "127.0.0.1:1234", @@ -107,7 +107,7 @@ func TestSetSelfProxy(t *testing.T) { wantHTTPS: "127.0.0.1:1234", }, { - name: "skip proxies", + name: "skip-proxies", env: map[string]string{ "HTTP_PROXY": "127.0.0.1:1234", "HTTPS_PROXY": "127.0.0.1:5678", @@ -117,7 +117,7 @@ func TestSetSelfProxy(t *testing.T) { wantHTTPS: "", // skipped }, { - name: "localhost normalization of env var", + name: "localhost-normalization-of-env-var", env: map[string]string{ "HTTP_PROXY": "localhost:1234", "HTTPS_PROXY": "[::1]:5678", @@ -127,7 +127,7 @@ func TestSetSelfProxy(t *testing.T) { wantHTTPS: "", // skipped }, { - name: "localhost normalization of addr", + name: "localhost-normalization-of-addr", env: map[string]string{ "HTTP_PROXY": "127.0.0.1:1234", "HTTPS_PROXY": "127.0.0.1:1234", @@ -137,7 +137,7 @@ func TestSetSelfProxy(t *testing.T) { wantHTTPS: "", // skipped }, { - name: "no ports", + name: "no-ports", env: map[string]string{ "HTTP_PROXY": "myproxy", "HTTPS_PROXY": "myproxy", diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 2f5d8c1d13254..cd75aff5ccffd 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -24,12 +24,14 @@ import ( "go4.org/mem" "tailscale.com/disco" "tailscale.com/envknob" + "tailscale.com/feature" "tailscale.com/feature/buildfeatures" "tailscale.com/net/packet" "tailscale.com/net/packet/checksum" "tailscale.com/net/tsaddr" "tailscale.com/syncs" "tailscale.com/tstime/mono" + "tailscale.com/types/events" "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -109,8 +111,7 @@ type Wrapper struct { // you might need to add an align64 field here. lastActivityAtomic mono.Time // time of last send or receive - destIPActivity syncs.AtomicValue[map[netip.Addr]func()] - discoKey syncs.AtomicValue[key.DiscoPublic] + discoKey syncs.AtomicValue[key.DiscoPublic] // timeNow, if non-nil, will be used to obtain the current time. timeNow func() time.Time @@ -220,7 +221,11 @@ type Wrapper struct { metrics *metrics eventClient *eventbus.Client - discoKeyAdvertisementPub *eventbus.Publisher[DiscoKeyAdvertisement] + discoKeyAdvertisementPub *eventbus.Publisher[events.DiscoKeyAdvertisement] + + // tunDevStatsCloser closes TUN device stats polling. It may be nil if + // [HookPollTUNDevStats] is unset, or the hook func returned an error. + tunDevStatsCloser io.Closer } type metrics struct { @@ -295,8 +300,18 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry, metrics: registerMetrics(m), } + if buildfeatures.HasTUNDevStats { + if f, ok := HookPollTUNDevStats.GetOk(); ok { + closer, err := f(tdev) + if err != nil { + w.logf("error initializing tun dev stats polling: %v", err) + } + w.tunDevStatsCloser = closer + } + } + w.eventClient = bus.Client("net.tstun") - w.discoKeyAdvertisementPub = eventbus.Publish[DiscoKeyAdvertisement](w.eventClient) + w.discoKeyAdvertisementPub = eventbus.Publish[events.DiscoKeyAdvertisement](w.eventClient) w.vectorBuffer = make([][]byte, tdev.BatchSize()) for i := range w.vectorBuffer { @@ -312,6 +327,9 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry, return w } +// HookPollTUNDevStats is the hook maybe set by feature/tundevstats. +var HookPollTUNDevStats feature.Hook[func(dev tun.Device) (io.Closer, error)] + // now returns the current time, either by calling t.timeNow if set or time.Now // if not. func (t *Wrapper) now() time.Time { @@ -321,16 +339,6 @@ func (t *Wrapper) now() time.Time { return time.Now() } -// SetDestIPActivityFuncs sets a map of funcs to run per packet -// destination (the map keys). -// -// The map ownership passes to the Wrapper. It must be non-nil. -func (t *Wrapper) SetDestIPActivityFuncs(m map[netip.Addr]func()) { - if buildfeatures.HasLazyWG { - t.destIPActivity.Store(m) - } -} - // SetDiscoKey sets the current discovery key. // // It is only used for filtering out bogus traffic when network @@ -373,6 +381,9 @@ func (t *Wrapper) Close() error { t.outboundMu.Unlock() err = t.tdev.Close() t.eventClient.Close() + if t.tunDevStatsCloser != nil { + t.tunDevStatsCloser.Close() + } }) return err } @@ -512,8 +523,9 @@ func (t *Wrapper) injectOutbound(r tunInjectedRead) { if t.outboundClosed { return } - t.vectorOutbound <- tunVectorReadResult{ - injected: r, + select { + case t.vectorOutbound <- tunVectorReadResult{injected: r}: + case <-t.closed: } } @@ -524,7 +536,10 @@ func (t *Wrapper) sendVectorOutbound(r tunVectorReadResult) { if t.outboundClosed { return } - t.vectorOutbound <- r + select { + case t.vectorOutbound <- r: + case <-t.closed: + } } // snat does SNAT on p if the destination address requires a different source address. @@ -971,13 +986,6 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { for _, data := range res.data { p.Decode(data[res.dataOffset:]) - if buildfeatures.HasLazyWG { - if m := t.destIPActivity.Load(); m != nil { - if fn := m[p.Dst.Addr()]; fn != nil { - fn() - } - } - } if buildfeatures.HasCapture && captHook != nil { captHook(packet.FromLocal, t.now(), p.Buffer(), p.CaptureMeta) } @@ -1110,14 +1118,6 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i pc.snat(p) invertGSOChecksum(pkt, gso) - if buildfeatures.HasLazyWG { - if m := t.destIPActivity.Load(); m != nil { - if fn := m[p.Dst.Addr()]; fn != nil { - fn() - } - } - } - if res.packet != nil { var gsoOptions tun.GSOOptions gsoOptions, err = stackGSOToTunGSO(pkt, gso) @@ -1140,13 +1140,6 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i return n, err } -// DiscoKeyAdvertisement is a TSMP message used for distributing disco keys. -// This struct is used an an event on the [eventbus.Bus]. -type DiscoKeyAdvertisement struct { - Src netip.Addr // Src field is populated by the IP header of the packet, not from the payload itself. - Key key.DiscoPublic -} - func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook packet.CaptureCallback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) { if captHook != nil { captHook(packet.FromPeer, t.now(), p.Buffer(), p.CaptureMeta) @@ -1158,8 +1151,8 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa t.injectOutboundPong(p, pingReq) return filter.DropSilently, gro } else if discoKeyAdvert, ok := p.AsTSMPDiscoAdvertisement(); ok { - if buildfeatures.HasCacheNetMap && envknob.Bool("TS_USE_CACHED_NETMAP") { - t.discoKeyAdvertisementPub.Publish(DiscoKeyAdvertisement{ + if buildfeatures.HasCacheNetMap && envknob.BoolDefaultTrue("TS_USE_CACHED_NETMAP") { + t.discoKeyAdvertisementPub.Publish(events.DiscoKeyAdvertisement{ Src: discoKeyAdvert.Src, Key: discoKeyAdvert.Key, }) @@ -1406,11 +1399,11 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt *netstack_PacketBuffer, buffs [] return err } } - for i := 0; i < n; i++ { + for i := range n { buffs[i] = buffs[i][:PacketStartOffset+sizes[i]] } defer func() { - for i := 0; i < n; i++ { + for i := range n { buffs[i] = buffs[i][:cap(buffs[i])] } }() diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index 1744fc30266a9..57b300513fec8 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -34,7 +34,6 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netlogtype" - "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/eventbus" "tailscale.com/util/eventbus/eventbustest" @@ -96,7 +95,7 @@ func tcp4syn(src, dst string, sport, dport uint16) []byte { func nets(nets ...string) (ret []netip.Prefix) { for _, s := range nets { - if i := strings.IndexByte(s, '/'); i == -1 { + if found := strings.Contains(s, "/"); !found { ip, err := netip.ParseAddr(s) if err != nil { panic(err) @@ -123,13 +122,13 @@ func ports(s string) filter.PortRange { } var fs, ls string - i := strings.IndexByte(s, '-') - if i == -1 { + before, after, ok := strings.Cut(s, "-") + if !ok { fs = s ls = fs } else { - fs = s[:i] - ls = s[i+1:] + fs = before + ls = after } first, err := strconv.ParseInt(fs, 10, 16) if err != nil { @@ -655,9 +654,9 @@ func TestPeerCfg_NAT(t *testing.T) { }, } if masqIP.Is4() { - p.V4MasqAddr = ptr.To(masqIP) + p.V4MasqAddr = new(masqIP) } else { - p.V6MasqAddr = ptr.To(masqIP) + p.V6MasqAddr = new(masqIP) } p.AllowedIPs = append(p.AllowedIPs, otherAllowedIPs...) return p diff --git a/net/udprelay/endpoint/endpoint_test.go b/net/udprelay/endpoint/endpoint_test.go index eaef289de6725..23fd88ad3e0bd 100644 --- a/net/udprelay/endpoint/endpoint_test.go +++ b/net/udprelay/endpoint/endpoint_test.go @@ -28,32 +28,32 @@ func TestServerEndpointJSONUnmarshal(t *testing.T) { wantErr: false, }, { - name: "invalid ServerDisco", + name: "invalid-ServerDisco", json: []byte(`{"ServerDisco":"1","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), wantErr: true, }, { - name: "invalid LamportID", + name: "invalid-LamportID", json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":1.1,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), wantErr: true, }, { - name: "invalid AddrPorts", + name: "invalid-AddrPorts", json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), wantErr: true, }, { - name: "invalid VNI", + name: "invalid-VNI", json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":18446744073709551615,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), wantErr: true, }, { - name: "invalid BindLifetime", + name: "invalid-BindLifetime", json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"5","SteadyStateLifetime":"5m0s"}`), wantErr: true, }, { - name: "invalid SteadyStateLifetime", + name: "invalid-SteadyStateLifetime", json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5"}`), wantErr: true, }, @@ -79,7 +79,7 @@ func TestServerEndpointJSONMarshal(t *testing.T) { serverEndpoint ServerEndpoint }{ { - name: "valid roundtrip", + name: "valid-roundtrip", serverEndpoint: ServerEndpoint{ ServerDisco: key.NewDisco().Public(), LamportID: uint64(math.MaxUint64), diff --git a/net/udprelay/server.go b/net/udprelay/server.go index 03d8e3dc3050d..3b0f729897ca8 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -1,9 +1,9 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -// Package udprelay contains constructs for relaying Disco and WireGuard packets -// between Tailscale clients over UDP. This package is currently considered -// experimental. +// Package udprelay contains a relay server implementation for relaying Disco +// and WireGuard packets between Tailscale clients over UDP. This relay +// functionality is also known as Tailscale Peer Relays. package udprelay import ( @@ -689,7 +689,7 @@ func (s *Server) bindSockets(desiredPort uint16) error { break SocketsLoop } } - pc := batching.TryUpgradeToConn(uc, network, batching.IdealBatchSize) + pc := batching.TryUpgradeToConn(uc, network, batching.IdealBatchSize, "udprelay_rxq_overflows") bc, ok := pc.(batching.Conn) if !ok { bc = &singlePacketConn{uc} @@ -977,7 +977,7 @@ func (e ErrServerNotReady) Error() string { // For now, we favor simplicity and reducing VNI re-use over more complex // ephemeral port (VNI) selection algorithms. func (s *Server) getNextVNILocked() (uint32, error) { - for i := uint32(0); i < totalPossibleVNI; i++ { + for range totalPossibleVNI { vni := s.nextVNI if vni == maxVNI { s.nextVNI = minVNI diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go index 66de0d88a7d0d..00b9c2423bd3c 100644 --- a/net/udprelay/server_test.go +++ b/net/udprelay/server_test.go @@ -196,15 +196,15 @@ func TestServer(t *testing.T) { forceClientsMixedAF bool }{ { - name: "over ipv4", + name: "over-ipv4", staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, }, { - name: "over ipv6", + name: "over-ipv6", staticAddrs: []netip.Addr{netip.MustParseAddr("::1")}, }, { - name: "mixed address families", + name: "mixed-address-families", staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")}, forceClientsMixedAF: true, }, @@ -265,7 +265,7 @@ func TestServer(t *testing.T) { tcB := newTestClient(t, endpoint.VNI, tcBServerEndpointAddr, discoB, discoA.Public(), endpoint.ServerDisco) defer tcB.close() - for i := 0; i < 2; i++ { + for range 2 { // We handshake both clients twice to guarantee server-side // packet reading goroutines, which are independent across // address families, have seen an answer from both clients @@ -345,7 +345,7 @@ func TestServer_getNextVNILocked(t *testing.T) { s := &Server{ nextVNI: minVNI, } - for i := uint64(0); i < uint64(totalPossibleVNI); i++ { + for range uint64(totalPossibleVNI) { vni, err := s.getNextVNILocked() if err != nil { // using quicktest here triples test time t.Fatal(err) diff --git a/pkgdoc_test.go b/pkgdoc_test.go index 60b2d4856d6c7..d0f0d66bd3a1d 100644 --- a/pkgdoc_test.go +++ b/pkgdoc_test.go @@ -4,6 +4,7 @@ package tailscaleroot import ( + "go/ast" "go/parser" "go/token" "os" @@ -13,6 +14,17 @@ import ( "testing" ) +func hasIgnoreBuildTag(f *ast.File) bool { + for _, cg := range f.Comments { + for _, c := range cg.List { + if c.Text == "//go:build ignore" { + return true + } + } + } + return false +} + func TestPackageDocs(t *testing.T) { switch runtime.GOOS { case "darwin", "linux": @@ -26,8 +38,11 @@ func TestPackageDocs(t *testing.T) { if err != nil { return err } - if fi.Mode().IsDir() && path == ".git" { - return filepath.SkipDir // No documentation lives in .git + if fi.Mode().IsDir() && path != "." && strings.HasPrefix(filepath.Base(path), ".") { + return filepath.SkipDir // No documentation lives in dot directories (.git, .claude, etc) + } + if fi.Mode().IsDir() && filepath.Base(path) == "testdata" { + return filepath.SkipDir // testdata is ignored by the go tool; not real packages } if fi.Mode().IsRegular() && strings.HasSuffix(path, ".go") { if strings.HasSuffix(path, "_test.go") { @@ -48,6 +63,9 @@ func TestPackageDocs(t *testing.T) { if err != nil { t.Fatalf("failed to ParseFile %q: %v", fileName, err) } + if hasIgnoreBuildTag(f) { + continue + } dir := filepath.Dir(fileName) if _, ok := byDir[dir]; !ok { byDir[dir] = nil @@ -61,14 +79,8 @@ func TestPackageDocs(t *testing.T) { } } for dir, ff := range byDir { - switch dir { - case "tstest/integration/vms": - // This package has a couple go:build ignore commands and this test doesn't - // handle parsing those. Just allowlist that package for now (2024-07-10). - continue - } if len(ff) > 1 { - t.Logf("multiple files with package doc in %s: %q", dir, ff) + t.Errorf("multiple files with package doc in %s: %q", dir, ff) } if len(ff) == 0 { if strings.HasPrefix(dir, "gokrazy/") { diff --git a/posture/serialnumber_stub.go b/posture/serialnumber_stub.go index e040aacfb30e2..6df9b4079650b 100644 --- a/posture/serialnumber_stub.go +++ b/posture/serialnumber_stub.go @@ -12,6 +12,7 @@ package posture import ( "errors" + "fmt" "tailscale.com/types/logger" "tailscale.com/util/syspolicy/policyclient" @@ -19,5 +20,5 @@ import ( // GetSerialNumber returns client machine serial number(s). func GetSerialNumbers(polc policyclient.Client, _ logger.Logf) ([]string, error) { - return nil, errors.New("not implemented") + return nil, fmt.Errorf("not implemented: %w", errors.ErrUnsupported) } diff --git a/prober/derp.go b/prober/derp.go index 73ea02cf5ad4f..dadda6fce2208 100644 --- a/prober/derp.go +++ b/prober/derp.go @@ -17,6 +17,7 @@ import ( "io" "log" "maps" + "math" "net" "net/http" "net/netip" @@ -423,7 +424,7 @@ func runDerpProbeQueuingDelayContinously(ctx context.Context, from, to *tailcfg. // for packets up to their timeout. As records age out of the front of this // list, if the associated packet arrives, we won't have a txRecord for it // and will consider it to have timed out. - txRecords := make([]txRecord, 0, packetsPerSecond*int(packetTimeout.Seconds())) + txRecords := make([]txRecord, 0, int(math.Ceil(float64(packetsPerSecond)*packetTimeout.Seconds()))+1) var txRecordsMu sync.Mutex // applyTimeouts walks over txRecords and expires any records that are older @@ -435,7 +436,7 @@ func runDerpProbeQueuingDelayContinously(ctx context.Context, from, to *tailcfg. now := time.Now() recs := txRecords[:0] for _, r := range txRecords { - if now.Sub(r.at) > packetTimeout { + if now.Sub(r.at) >= packetTimeout { packetsDropped.Add(1) } else { recs = append(recs, r) @@ -451,9 +452,7 @@ func runDerpProbeQueuingDelayContinously(ctx context.Context, from, to *tailcfg. pkt := make([]byte, 260) // the same size as a CallMeMaybe packet observed on a Tailscale client. crand.Read(pkt) - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { t := time.NewTicker(time.Second / time.Duration(packetsPerSecond)) defer t.Stop() @@ -481,13 +480,11 @@ func runDerpProbeQueuingDelayContinously(ctx context.Context, from, to *tailcfg. } } } - }() + }) // Receive the packets. recvFinishedC := make(chan error, 1) - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { defer close(recvFinishedC) // to break out of 'select' below. fromDERPPubKey := fromc.SelfPublicKey() for { @@ -531,7 +528,7 @@ func runDerpProbeQueuingDelayContinously(ctx context.Context, from, to *tailcfg. // Loop. } } - }() + }) select { case <-ctx.Done(): diff --git a/prober/prober.go b/prober/prober.go index 3a43401a14ac3..40eef2faf43b1 100644 --- a/prober/prober.go +++ b/prober/prober.go @@ -122,12 +122,8 @@ func (p *Prober) Run(name string, interval time.Duration, labels Labels, pc Prob "name": name, "class": pc.Class, } - for k, v := range pc.Labels { - lb[k] = v - } - for k, v := range labels { - lb[k] = v - } + maps.Copy(lb, pc.Labels) + maps.Copy(lb, labels) probe := newProbe(p, name, interval, lb, pc) p.probes[name] = probe diff --git a/pull-toolchain.sh b/pull-toolchain.sh index c80c913bb17b2..8f34129c66053 100755 --- a/pull-toolchain.sh +++ b/pull-toolchain.sh @@ -20,15 +20,41 @@ if [ "$upstream" != "$current" ]; then echo "$upstream" >"$go_toolchain_rev_file" fi -# Only update go.toolchain.version and go.toolchain.rev.sri for the main toolchain, +# When updating the regular (non-next) toolchain, also bump go.toolchain.next.rev +# if it has fallen behind on the same branch. This happens when "next" was tracking +# a release candidate (e.g. Go 1.26.0rc2) and the regular toolchain later gets +# bumped to a newer release (e.g. Go 1.26.2) on the same branch. At that point +# the "next" rev shouldn't still point at the older RC. +if [ "${TS_GO_NEXT:-}" != "1" ]; then + read -r next_branch /dev/null; then + if git -C "$tmpdir" merge-base --is-ancestor "$next_rev" "$new_rev" 2>/dev/null; then + echo "$new_rev" >go.toolchain.next.rev + echo "pull-toolchain.sh: also bumped go.toolchain.next.rev to match (was behind on same branch)" >&2 + fi + fi + rm -rf "$tmpdir" + fi + fi +fi + +# Only update go.toolchain.version and flakehashes.json for the main toolchain, # skipping it if TS_GO_NEXT=1. Those two files are only used by Nix, and as of 2026-01-26 # don't yet support TS_GO_NEXT=1 with flake.nix or in our corp CI. if [ "${TS_GO_NEXT:-}" != "1" ]; then ./tool/go version 2>/dev/null | awk '{print $3}' | sed 's/^go//' > go.toolchain.version ./tool/go mod edit -go "$(cat go.toolchain.version)" - ./update-flake.sh + ./tool/go run ./tool/updateflakes fi -if [ -n "$(git diff-index --name-only HEAD -- "$go_toolchain_rev_file" go.toolchain.rev.sri go.toolchain.version)" ]; then +if [ -n "$(git diff-index --name-only HEAD -- "$go_toolchain_rev_file" go.toolchain.next.rev flakehashes.json go.toolchain.version)" ]; then echo "pull-toolchain.sh: changes imported. Use git commit to make them permanent." >&2 fi diff --git a/release/dist/qnap/files/scripts/build-qpkg.sh b/release/dist/qnap/files/scripts/build-qpkg.sh index d478bfe6b26e6..61786ead829ca 100755 --- a/release/dist/qnap/files/scripts/build-qpkg.sh +++ b/release/dist/qnap/files/scripts/build-qpkg.sh @@ -4,17 +4,9 @@ set -eu # Clean up folders and files created during build. function cleanup() { - rm -rf /Tailscale/$ARCH - rm -f /Tailscale/sed* - rm -f /Tailscale/qpkg.cfg - - # If this build was signed, a .qpkg.codesigning file will be created as an - # artifact of the build - # (see https://github.com/qnap-dev/qdk2/blob/93ac75c76941b90ee668557f7ce01e4b23881054/QDK_2.x/bin/qbuild#L992). - # - # go/client-release doesn't seem to need these, so we delete them here to - # avoid uploading them to pkgs.tailscale.com. - rm -f /out/*.qpkg.codesigning + rm -rf /Tailscale/$ARCH + rm -f /Tailscale/sed* + rm -f /Tailscale/qpkg.cfg } trap cleanup EXIT @@ -22,6 +14,6 @@ mkdir -p /Tailscale/$ARCH cp /tailscaled /Tailscale/$ARCH/tailscaled cp /tailscale /Tailscale/$ARCH/tailscale -sed "s/\$QPKG_VER/$TSTAG-$QNAPTAG/g" /Tailscale/qpkg.cfg.in > /Tailscale/qpkg.cfg +sed "s/\$QPKG_VER/$TSTAG-$QNAPTAG/g" /Tailscale/qpkg.cfg.in >/Tailscale/qpkg.cfg qbuild --root /Tailscale --build-arch $ARCH --build-dir /out diff --git a/release/dist/qnap/pkgs.go b/release/dist/qnap/pkgs.go index 1d69b3eaf3500..b505b1ac0e908 100644 --- a/release/dist/qnap/pkgs.go +++ b/release/dist/qnap/pkgs.go @@ -118,7 +118,16 @@ func (t *target) buildQPKG(b *dist.Build, qnapBuilds *qnapBuilds, inner *innerPk return nil, fmt.Errorf("docker run %v: %s", err, out) } - return []string{filePath, filePath + ".md5"}, nil + ret := []string{filePath, filePath + ".md5"} + // If the build was signed, a .codesigning file is produced containing + // the last 32 characters of the base64-encoded CMS signature. This is + // used by pkgserve to populate entries in the QNAP + // repository XML. + codesigning := filePath + ".codesigning" + if _, err := os.Stat(codesigning); err == nil { + ret = append(ret, codesigning) + } + return ret, nil } type qnapBuildsMemoizeKey struct{} diff --git a/safesocket/safesocket.go b/safesocket/safesocket.go index 6be8ae5b8fac3..60291e1340558 100644 --- a/safesocket/safesocket.go +++ b/safesocket/safesocket.go @@ -120,7 +120,7 @@ func PlatformUsesPeerCreds() bool { // runtime.GOOS value instead of using the current one. func GOOSUsesPeerCreds(goos string) bool { switch goos { - case "linux", "darwin", "freebsd": + case "linux", "darwin", "freebsd", "solaris", "illumos": return true } return false diff --git a/safesocket/safesocket_darwin.go b/safesocket/safesocket_darwin.go index 8cbabff63364e..aa67baaf82596 100644 --- a/safesocket/safesocket_darwin.go +++ b/safesocket/safesocket_darwin.go @@ -102,8 +102,8 @@ func SetCredentials(token string, port int) { // InitListenerDarwin initializes the listener for the CLI commands // and localapi HTTP server and sets the port/token. This will override -// any credentials set explicitly via SetCredentials(). Calling this mulitple times -// has no effect. The listener and it's corresponding token/port is initialized only once. +// any credentials set explicitly via SetCredentials(). Calling this multiple times +// has no effect. The listener and its corresponding token/port is initialized only once. func InitListenerDarwin(sharedDir string) (*net.Listener, error) { ssd.mu.Lock() defer ssd.mu.Unlock() diff --git a/safeweb/http.go b/safeweb/http.go index f76591cbd0e16..d52412bd35cc2 100644 --- a/safeweb/http.go +++ b/safeweb/http.go @@ -74,6 +74,7 @@ import ( "context" crand "crypto/rand" "fmt" + "html/template" "log" "maps" "net" @@ -268,7 +269,7 @@ func NewServer(config Config) (*Server, error) { csp: config.CSP.String(), // only set Secure flag on CSRF cookies if we are in a secure context // as otherwise the browser will reject the cookie - csrfProtect: csrf.Protect(config.CSRFSecret, csrf.Secure(config.SecureContext), csrf.SameSite(sameSite)), + csrfProtect: csrf.Protect(config.CSRFSecret, csrf.Secure(config.SecureContext), csrf.SameSite(sameSite), csrf.Path("/")), } s.h = cmp.Or(config.HTTPServer, &http.Server{}) if s.h.Handler != nil { @@ -428,3 +429,12 @@ func (s *Server) Close() error { // Shutdown gracefully shuts down the server without interrupting any active // connections. It has the same semantics as[http.Server.Shutdown]. func (s *Server) Shutdown(ctx context.Context) error { return s.h.Shutdown(ctx) } + +// CSRFToken returns the masked CSRF token for the current request. Use this +// to pass the token in JSON responses or custom headers. +func CSRFToken(r *http.Request) string { return csrf.Token(r) } + +// CSRFTemplateField returns a hidden HTML input element containing the CSRF +// token for the current request. Use this in HTML forms served by the +// BrowserMux. +func CSRFTemplateField(r *http.Request) template.HTML { return csrf.TemplateField(r) } diff --git a/safeweb/http_test.go b/safeweb/http_test.go index cbac7210a4807..fb298eb6ea18f 100644 --- a/safeweb/http_test.go +++ b/safeweb/http_test.go @@ -41,25 +41,25 @@ func TestPostRequestContentTypeValidation(t *testing.T) { wantErr bool }{ { - name: "API routes should accept `application/json` content-type", + name: "API-accept-application-json", browserRoute: false, contentType: "application/json", wantErr: false, }, { - name: "API routes should reject `application/x-www-form-urlencoded` content-type", + name: "API-reject-form-urlencoded", browserRoute: false, contentType: "application/x-www-form-urlencoded", wantErr: true, }, { - name: "Browser routes should accept `application/x-www-form-urlencoded` content-type", + name: "browser-accept-form-urlencoded", browserRoute: true, contentType: "application/x-www-form-urlencoded", wantErr: false, }, { - name: "non Browser routes should accept `application/json` content-type", + name: "browser-accept-application-json", browserRoute: true, contentType: "application/json", wantErr: false, @@ -106,21 +106,21 @@ func TestAPIMuxCrossOriginResourceSharingHeaders(t *testing.T) { corsMethods []string }{ { - name: "do not set CORS headers for non-OPTIONS requests", + name: "no-CORS-headers-for-non-OPTIONS", corsOrigins: []string{"https://foobar.com"}, corsMethods: []string{"GET", "POST", "HEAD"}, httpMethod: "GET", wantCORSHeaders: false, }, { - name: "set CORS headers for non-OPTIONS requests", + name: "CORS-headers-for-OPTIONS", corsOrigins: []string{"https://foobar.com"}, corsMethods: []string{"GET", "POST", "HEAD"}, httpMethod: "OPTIONS", wantCORSHeaders: true, }, { - name: "do not serve CORS headers for OPTIONS requests with no configured origins", + name: "no-CORS-headers-for-OPTIONS-without-origins", httpMethod: "OPTIONS", wantCORSHeaders: false, }, @@ -162,19 +162,19 @@ func TestCSRFProtection(t *testing.T) { wantStatus int }{ { - name: "POST requests to non-API routes require CSRF token and fail if not provided", + name: "non-API-POST-without-CSRF-fails", apiRoute: false, passCSRFToken: false, wantStatus: http.StatusForbidden, }, { - name: "POST requests to non-API routes require CSRF token and pass if provided", + name: "non-API-POST-with-CSRF-passes", apiRoute: false, passCSRFToken: true, wantStatus: http.StatusOK, }, { - name: "POST requests to /api/ routes do not require CSRF token", + name: "API-POST-without-CSRF-passes", apiRoute: true, passCSRFToken: false, wantStatus: http.StatusOK, @@ -246,11 +246,11 @@ func TestContentSecurityPolicyHeader(t *testing.T) { wantCSP string }{ { - name: "default CSP", + name: "default-CSP", wantCSP: `base-uri 'self'; block-all-mixed-content; default-src 'self'; form-action 'self'; frame-ancestors 'none';`, }, { - name: "custom CSP", + name: "custom-CSP", csp: CSP{ "default-src": {"'self'", "https://tailscale.com"}, "upgrade-insecure-requests": nil, @@ -258,7 +258,7 @@ func TestContentSecurityPolicyHeader(t *testing.T) { wantCSP: `default-src 'self' https://tailscale.com; upgrade-insecure-requests;`, }, { - name: "`/api/*` routes do not get CSP headers", + name: "api-routes-no-CSP-headers", apiRoute: true, wantCSP: "", }, @@ -301,12 +301,12 @@ func TestCSRFCookieSecureMode(t *testing.T) { wantSecure bool }{ { - name: "CSRF cookie should be secure when server is in secure context", + name: "secure-context-cookie-secure", secureMode: true, wantSecure: true, }, { - name: "CSRF cookie should not be secure when server is not in secure context", + name: "non-secure-context-cookie-not-secure", secureMode: false, wantSecure: false, }, @@ -343,12 +343,12 @@ func TestRefererPolicy(t *testing.T) { wantRefererPolicy bool }{ { - name: "BrowserMux routes get Referer-Policy headers", + name: "BrowserMux-gets-Referer-Policy", browserRoute: true, wantRefererPolicy: true, }, { - name: "APIMux routes do not get Referer-Policy headers", + name: "APIMux-no-Referer-Policy", browserRoute: false, wantRefererPolicy: false, }, @@ -420,54 +420,54 @@ func TestRouting(t *testing.T) { want string }{ { - desc: "only browser mux", + desc: "only-browser-mux", browserPatterns: []string{"/"}, requestPath: "/index.html", want: "browser", }, { - desc: "only API mux", + desc: "only-API-mux", apiPatterns: []string{"/api/"}, requestPath: "/api/foo", want: "api", }, { - desc: "browser mux match", + desc: "browser-mux-match", browserPatterns: []string{"/content/"}, apiPatterns: []string{"/api/"}, requestPath: "/content/index.html", want: "browser", }, { - desc: "API mux match", + desc: "API-mux-match", browserPatterns: []string{"/content/"}, apiPatterns: []string{"/api/"}, requestPath: "/api/foo", want: "api", }, { - desc: "browser wildcard match", + desc: "browser-wildcard-match", browserPatterns: []string{"/"}, apiPatterns: []string{"/api/"}, requestPath: "/index.html", want: "browser", }, { - desc: "API wildcard match", + desc: "API-wildcard-match", browserPatterns: []string{"/content/"}, apiPatterns: []string{"/"}, requestPath: "/api/foo", want: "api", }, { - desc: "path conflict", + desc: "path-conflict", browserPatterns: []string{"/foo/"}, apiPatterns: []string{"/foo/bar/"}, requestPath: "/foo/bar/baz", want: "api", }, { - desc: "no match", + desc: "no-match", browserPatterns: []string{"/foo/"}, apiPatterns: []string{"/bar/"}, requestPath: "/baz", @@ -521,43 +521,43 @@ func TestGetMoreSpecificPattern(t *testing.T) { want: unknownHandler, }, { - desc: "identical prefix", + desc: "identical-prefix", a: "/foo/bar/", b: "/foo/bar/", want: unknownHandler, }, { - desc: "trailing slash", + desc: "trailing-slash", a: "/foo", b: "/foo/", // path.Clean will strip the trailing slash. want: unknownHandler, }, { - desc: "same prefix", + desc: "same-prefix", a: "/foo/bar/quux", b: "/foo/bar/", // path.Clean will strip the trailing slash. want: apiHandler, }, { - desc: "almost same prefix, but not a path component", + desc: "almost-same-prefix-not-path-component", a: "/goat/sheep/cheese", b: "/goat/sheepcheese/", // path.Clean will strip the trailing slash. want: apiHandler, }, { - desc: "attempt to make less-specific pattern look more specific", + desc: "traversal-less-specific-pattern", a: "/goat/cat/buddy", b: "/goat/../../../../../../../cat", // path.Clean catches this foolishness want: apiHandler, }, { - desc: "2 names for / (1)", + desc: "two-names-for-root-1", a: "/", b: "/../../../../../../", want: unknownHandler, }, { - desc: "2 names for / (2)", + desc: "two-names-for-root-2", a: "/", b: "///////", want: unknownHandler, @@ -586,15 +586,15 @@ func TestStrictTransportSecurityOptions(t *testing.T) { expect string }{ { - name: "off by default", + name: "off-by-default", }, { - name: "default HSTS options in the secure context", + name: "default-HSTS-in-secure-context", secureContext: true, expect: DefaultStrictTransportSecurityOptions, }, { - name: "custom options sent in the secure context", + name: "custom-options-in-secure-context", options: DefaultStrictTransportSecurityOptions + "; includeSubDomains", secureContext: true, expect: DefaultStrictTransportSecurityOptions + "; includeSubDomains", diff --git a/sessionrecording/connect_test.go b/sessionrecording/connect_test.go index 64bcb1c3185d3..3d1feff12f8de 100644 --- a/sessionrecording/connect_test.go +++ b/sessionrecording/connect_test.go @@ -35,7 +35,7 @@ func TestConnectToRecorder(t *testing.T) { wantErr bool }{ { - desc: "v1 recorder", + desc: "v1-recorder", setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { uploadHash := make(chan []byte, 1) mux := http.NewServeMux() @@ -50,7 +50,7 @@ func TestConnectToRecorder(t *testing.T) { }, }, { - desc: "v2 recorder", + desc: "v2-recorder", http2: true, setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { uploadHash := make(chan []byte, 1) @@ -100,7 +100,7 @@ func TestConnectToRecorder(t *testing.T) { }, }, { - desc: "v2 recorder no acks", + desc: "v2-recorder-no-acks", http2: true, wantErr: true, setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { diff --git a/shell.nix b/shell.nix index 7ddf62c52df5c..648f101a8dbf1 100644 --- a/shell.nix +++ b/shell.nix @@ -16,4 +16,4 @@ ) { src = ./.; }).shellNix -# nix-direnv cache busting line: sha256-rhuWEEN+CtumVxOw6Dy/IRxWIrZ2x6RJb6ULYwXCQc4= +# nix-direnv cache busting line: sha256-mbxLXR2TBgiwyVGfLmMR5xWk+0f66mPDas95Wla70Lk= diff --git a/ssh/tailssh/accept_env_test.go b/ssh/tailssh/accept_env_test.go index 25787db302357..fef13877a85c5 100644 --- a/ssh/tailssh/accept_env_test.go +++ b/ssh/tailssh/accept_env_test.go @@ -111,25 +111,25 @@ func TestFilterEnv(t *testing.T) { wantErrMessage string }{ { - name: "simple direct matches", + name: "simple-direct-matches", acceptEnv: []string{"FOO", "FOO2", "FOO_3"}, environ: []string{"FOO=BAR", "FOO2=BAZ", "FOO_3=123", "FOOOO4-2=AbCdEfG"}, expectedFiltered: []string{"FOO=BAR", "FOO2=BAZ", "FOO_3=123"}, }, { - name: "bare wildcard", + name: "bare-wildcard", acceptEnv: []string{"*"}, environ: []string{"FOO=BAR", "FOO2=BAZ", "FOO_3=123", "FOOOO4-2=AbCdEfG"}, expectedFiltered: []string{"FOO=BAR", "FOO2=BAZ", "FOO_3=123", "FOOOO4-2=AbCdEfG"}, }, { - name: "complex matches", + name: "complex-matches", acceptEnv: []string{"FO?", "FOOO*", "FO*5?7"}, environ: []string{"FOO=BAR", "FOO2=BAZ", "FOO_3=123", "FOOOO4-2=AbCdEfG", "FO1-kmndGamc79567=ABC", "FO57=BAR2"}, expectedFiltered: []string{"FOO=BAR", "FOOOO4-2=AbCdEfG", "FO1-kmndGamc79567=ABC"}, }, { - name: "environ format invalid", + name: "environ-format-invalid", acceptEnv: []string{"FO?", "FOOO*", "FO*5?7"}, environ: []string{"FOOBAR"}, expectedFiltered: nil, diff --git a/ssh/tailssh/c2n.go b/ssh/tailssh/c2n.go new file mode 100644 index 0000000000000..621be74d4baba --- /dev/null +++ b/ssh/tailssh/c2n.go @@ -0,0 +1,109 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 + +package tailssh + +import ( + "bytes" + "encoding/json" + "net/http" + "os/exec" + "runtime" + "slices" + + "go4.org/mem" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tailcfg" + "tailscale.com/util/lineiter" +) + +func handleC2NSSHUsernames(b *ipnlocal.LocalBackend, w http.ResponseWriter, r *http.Request) { + var req tailcfg.C2NSSHUsernamesRequest + if r.Method == "POST" { + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } + res, err := getSSHUsernames(b, &req) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) +} + +// getSSHUsernames discovers and returns the list of usernames that are +// potential Tailscale SSH user targets. +func getSSHUsernames(b *ipnlocal.LocalBackend, req *tailcfg.C2NSSHUsernamesRequest) (*tailcfg.C2NSSHUsernamesResponse, error) { + res := new(tailcfg.C2NSSHUsernamesResponse) + if b == nil || !b.ShouldRunSSH() { + return res, nil + } + + max := 10 + if req != nil && req.Max != 0 { + max = req.Max + } + + add := func(u string) { + if req != nil && req.Exclude[u] { + return + } + switch u { + case "nobody", "daemon", "sync": + return + } + if slices.Contains(res.Usernames, u) { + return + } + if len(res.Usernames) > max { + // Enough for a hint. + return + } + res.Usernames = append(res.Usernames, u) + } + + if opUser := b.OperatorUserName(); opUser != "" { + add(opUser) + } + + // Check popular usernames and see if they exist with a real shell. + switch runtime.GOOS { + case "darwin": + out, err := exec.Command("dscl", ".", "list", "/Users").Output() + if err != nil { + return nil, err + } + for line := range lineiter.Bytes(out) { + line = bytes.TrimSpace(line) + if len(line) == 0 || line[0] == '_' { + continue + } + add(string(line)) + } + default: + for lr := range lineiter.File("/etc/passwd") { + line, err := lr.Value() + if err != nil { + break + } + line = bytes.TrimSpace(line) + if len(line) == 0 || line[0] == '#' || line[0] == '_' { + continue + } + if mem.HasSuffix(mem.B(line), mem.S("/nologin")) || + mem.HasSuffix(mem.B(line), mem.S("/false")) { + continue + } + before, _, ok := bytes.Cut(line, []byte{':'}) + if ok { + add(string(before)) + } + } + } + return res, nil +} diff --git a/ipn/ipnlocal/ssh.go b/ssh/tailssh/hostkeys.go similarity index 50% rename from ipn/ipnlocal/ssh.go rename to ssh/tailssh/hostkeys.go index 52b3066584e08..8046a021a9308 100644 --- a/ipn/ipnlocal/ssh.go +++ b/ssh/tailssh/hostkeys.go @@ -1,9 +1,9 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9) && !ts_omit_ssh +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 -package ipnlocal +package tailssh import ( "bytes" @@ -17,17 +17,12 @@ import ( "errors" "fmt" "os" - "os/exec" "path/filepath" - "runtime" - "slices" "strings" "sync" - "go4.org/mem" "golang.org/x/crypto/ssh" - "tailscale.com/tailcfg" - "tailscale.com/util/lineiter" + "tailscale.com/types/logger" "tailscale.com/util/mak" ) @@ -36,91 +31,32 @@ import ( // running as root. var keyTypes = []string{"rsa", "ecdsa", "ed25519"} -// getSSHUsernames discovers and returns the list of usernames that are -// potential Tailscale SSH user targets. -// -// Invariant: must not be called with b.mu held. -func (b *LocalBackend) getSSHUsernames(req *tailcfg.C2NSSHUsernamesRequest) (*tailcfg.C2NSSHUsernamesResponse, error) { - res := new(tailcfg.C2NSSHUsernamesResponse) - if !b.tailscaleSSHEnabled() { - return res, nil - } - - max := 10 - if req != nil && req.Max != 0 { - max = req.Max - } - - add := func(u string) { - if req != nil && req.Exclude[u] { - return - } - switch u { - case "nobody", "daemon", "sync": - return - } - if slices.Contains(res.Usernames, u) { - return - } - if len(res.Usernames) > max { - // Enough for a hint. - return - } - res.Usernames = append(res.Usernames, u) - } - - if opUser := b.operatorUserName(); opUser != "" { - add(opUser) - } - - // Check popular usernames and see if they exist with a real shell. - switch runtime.GOOS { - case "darwin": - out, err := exec.Command("dscl", ".", "list", "/Users").Output() - if err != nil { - return nil, err - } - for line := range lineiter.Bytes(out) { - line = bytes.TrimSpace(line) - if len(line) == 0 || line[0] == '_' { - continue - } - add(string(line)) - } - default: - for lr := range lineiter.File("/etc/passwd") { - line, err := lr.Value() - if err != nil { - break - } - line = bytes.TrimSpace(line) - if len(line) == 0 || line[0] == '#' || line[0] == '_' { - continue - } - if mem.HasSuffix(mem.B(line), mem.S("/nologin")) || - mem.HasSuffix(mem.B(line), mem.S("/false")) { - continue - } - colon := bytes.IndexByte(line, ':') - if colon != -1 { - add(string(line[:colon])) - } - } +// getHostKeys returns the SSH host keys, using system keys when running as root +// and generating Tailscale-specific keys as needed. +func getHostKeys(varRoot string, logf logger.Logf) ([]ssh.Signer, error) { + var existing map[string]ssh.Signer + if os.Geteuid() == 0 { + existing = getSystemHostKeys(logf) } - return res, nil + return getTailscaleHostKeys(varRoot, existing) } -func (b *LocalBackend) GetSSH_HostKeys() (keys []ssh.Signer, err error) { - var existing map[string]ssh.Signer - if os.Geteuid() == 0 { - existing = b.getSystemSSH_HostKeys() +// getHostKeyPublicStrings returns the SSH host key public key strings. +func getHostKeyPublicStrings(varRoot string, logf logger.Logf) ([]string, error) { + signers, err := getHostKeys(varRoot, logf) + if err != nil { + return nil, err + } + var keyStrings []string + for _, signer := range signers { + keyStrings = append(keyStrings, strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey())))) } - return b.getTailscaleSSH_HostKeys(existing) + return keyStrings, nil } -// getTailscaleSSH_HostKeys returns the three (rsa, ecdsa, ed25519) SSH host +// getTailscaleHostKeys returns the three (rsa, ecdsa, ed25519) SSH host // keys, reusing the provided ones in existing if present in the map. -func (b *LocalBackend) getTailscaleSSH_HostKeys(existing map[string]ssh.Signer) (keys []ssh.Signer, err error) { +func getTailscaleHostKeys(varRoot string, existing map[string]ssh.Signer) (keys []ssh.Signer, err error) { var keyDir string // lazily initialized $TAILSCALE_VAR/ssh dir. for _, typ := range keyTypes { if s, ok := existing[typ]; ok { @@ -128,16 +64,15 @@ func (b *LocalBackend) getTailscaleSSH_HostKeys(existing map[string]ssh.Signer) continue } if keyDir == "" { - root := b.TailscaleVarRoot() - if root == "" { + if varRoot == "" { return nil, errors.New("no var root for ssh keys") } - keyDir = filepath.Join(root, "ssh") + keyDir = filepath.Join(varRoot, "ssh") if err := os.MkdirAll(keyDir, 0700); err != nil { return nil, err } } - hostKey, err := b.hostKeyFileOrCreate(keyDir, typ) + hostKey, err := hostKeyFileOrCreate(keyDir, typ) if err != nil { return nil, fmt.Errorf("error creating SSH host key type %q in %q: %w", typ, keyDir, err) } @@ -150,9 +85,16 @@ func (b *LocalBackend) getTailscaleSSH_HostKeys(existing map[string]ssh.Signer) return keys, nil } +// keyGenMu protects concurrent generation of host keys with +// [hostKeyFileOrCreate], making sure two callers don't try to concurrently find +// a missing key and generate it at the same time, returning different keys to +// their callers. +// +// Technically we actually want to have a mutex per directory (the keyDir +// passed), but that's overkill for how rarely keys are loaded or generated. var keyGenMu sync.Mutex -func (b *LocalBackend) hostKeyFileOrCreate(keyDir, typ string) ([]byte, error) { +func hostKeyFileOrCreate(keyDir, typ string) ([]byte, error) { keyGenMu.Lock() defer keyGenMu.Unlock() @@ -195,7 +137,7 @@ func (b *LocalBackend) hostKeyFileOrCreate(keyDir, typ string) ([]byte, error) { return pemGen, err } -func (b *LocalBackend) getSystemSSH_HostKeys() (ret map[string]ssh.Signer) { +func getSystemHostKeys(logf logger.Logf) (ret map[string]ssh.Signer) { for _, typ := range keyTypes { filename := "/etc/ssh/ssh_host_" + typ + "_key" hostKey, err := os.ReadFile(filename) @@ -204,31 +146,10 @@ func (b *LocalBackend) getSystemSSH_HostKeys() (ret map[string]ssh.Signer) { } signer, err := ssh.ParsePrivateKey(hostKey) if err != nil { - b.logf("warning: error reading host key %s: %v (generating one instead)", filename, err) + logf("warning: error reading host key %s: %v (generating one instead)", filename, err) continue } mak.Set(&ret, typ, signer) } return ret } - -func (b *LocalBackend) getSSHHostKeyPublicStrings() ([]string, error) { - signers, err := b.GetSSH_HostKeys() - if err != nil { - return nil, err - } - var keyStrings []string - for _, signer := range signers { - keyStrings = append(keyStrings, strings.TrimSpace(string(ssh.MarshalAuthorizedKey(signer.PublicKey())))) - } - return keyStrings, nil -} - -// tailscaleSSHEnabled reports whether Tailscale SSH is currently enabled based -// on prefs. It returns false if there are no prefs set. -func (b *LocalBackend) tailscaleSSHEnabled() bool { - b.mu.Lock() - defer b.mu.Unlock() - p := b.pm.CurrentPrefs() - return p.Valid() && p.RunSSH() -} diff --git a/ssh/tailssh/hostkeys_test.go b/ssh/tailssh/hostkeys_test.go new file mode 100644 index 0000000000000..24a876454ea6e --- /dev/null +++ b/ssh/tailssh/hostkeys_test.go @@ -0,0 +1,39 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || (darwin && !ios) + +package tailssh + +import ( + "reflect" + "testing" +) + +func TestSSHKeyGen(t *testing.T) { + dir := t.TempDir() + keys, err := getTailscaleHostKeys(dir, nil) + if err != nil { + t.Fatal(err) + } + got := map[string]bool{} + for _, k := range keys { + got[k.PublicKey().Type()] = true + } + want := map[string]bool{ + "ssh-rsa": true, + "ecdsa-sha2-nistp256": true, + "ssh-ed25519": true, + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("keys = %v; want %v", got, want) + } + + keys2, err := getTailscaleHostKeys(dir, nil) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(keys, keys2) { + t.Errorf("got different keys on second call") + } +} diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index b414ce3fbf42a..48c65e8e51446 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -35,13 +35,13 @@ import ( "github.com/creack/pty" "github.com/pkg/sftp" + gliderssh "github.com/tailscale/gliderssh" "github.com/u-root/u-root/pkg/termios" - gossh "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh" "golang.org/x/sys/unix" "tailscale.com/cmd/tailscaled/childproc" "tailscale.com/hostinfo" "tailscale.com/tailcfg" - "tailscale.com/tempfork/gliderlabs/ssh" "tailscale.com/types/logger" "tailscale.com/version/distro" ) @@ -158,8 +158,7 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err cmd.Dir = "/" case errors.Is(err, fs.ErrPermission) || errors.Is(err, fs.ErrNotExist): // Ensure that cmd.Dir is the source of the error. - var pathErr *fs.PathError - if errors.As(err, &pathErr) && pathErr.Path == cmd.Dir { + if pathErr, ok := errors.AsType[*fs.PathError](err); ok && pathErr.Path == cmd.Dir { // If we cannot run loginShell in localUser.HomeDir, // we will try to run this command in the root directory. cmd.Dir = "/" @@ -203,7 +202,7 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err incubatorArgs = append(incubatorArgs, "--is-selinux-enforcing") } - nm := ss.conn.srv.lb.NetMap() + nm := ss.conn.srv.lb.NetMapNoPeers() forceV1Behavior := nm.HasCap(tailcfg.NodeAttrSSHBehaviorV1) && !nm.HasCap(tailcfg.NodeAttrSSHBehaviorV2) if forceV1Behavior { incubatorArgs = append(incubatorArgs, "--force-v1-behavior") @@ -312,7 +311,7 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) { flags.StringVar(&ia.encodedEnv, "encoded-env", "", "JSON encoded array of environment variables in '['key=value']' format") flags.Parse(args) - for _, g := range strings.Split(groups, ",") { + for g := range strings.SplitSeq(groups, ",") { gid, err := strconv.Atoi(g) if err != nil { return ia, fmt.Errorf("unable to parse group id %q: %w", g, err) @@ -898,7 +897,7 @@ func (ss *sshSession) launchProcess() error { return nil } -func resizeWindow(fd int, winCh <-chan ssh.Window) { +func resizeWindow(fd int, winCh <-chan gliderssh.Window) { for win := range winCh { unix.IoctlSetWinsize(fd, syscall.TIOCSWINSZ, &unix.Winsize{ Row: uint16(win.Height), @@ -913,62 +912,62 @@ func resizeWindow(fd int, winCh <-chan ssh.Window) { // to mnemonic names expected by the termios package. // These are meant to be platform independent. var opcodeShortName = map[uint8]string{ - gossh.VINTR: "intr", - gossh.VQUIT: "quit", - gossh.VERASE: "erase", - gossh.VKILL: "kill", - gossh.VEOF: "eof", - gossh.VEOL: "eol", - gossh.VEOL2: "eol2", - gossh.VSTART: "start", - gossh.VSTOP: "stop", - gossh.VSUSP: "susp", - gossh.VDSUSP: "dsusp", - gossh.VREPRINT: "rprnt", - gossh.VWERASE: "werase", - gossh.VLNEXT: "lnext", - gossh.VFLUSH: "flush", - gossh.VSWTCH: "swtch", - gossh.VSTATUS: "status", - gossh.VDISCARD: "discard", - gossh.IGNPAR: "ignpar", - gossh.PARMRK: "parmrk", - gossh.INPCK: "inpck", - gossh.ISTRIP: "istrip", - gossh.INLCR: "inlcr", - gossh.IGNCR: "igncr", - gossh.ICRNL: "icrnl", - gossh.IUCLC: "iuclc", - gossh.IXON: "ixon", - gossh.IXANY: "ixany", - gossh.IXOFF: "ixoff", - gossh.IMAXBEL: "imaxbel", - gossh.IUTF8: "iutf8", - gossh.ISIG: "isig", - gossh.ICANON: "icanon", - gossh.XCASE: "xcase", - gossh.ECHO: "echo", - gossh.ECHOE: "echoe", - gossh.ECHOK: "echok", - gossh.ECHONL: "echonl", - gossh.NOFLSH: "noflsh", - gossh.TOSTOP: "tostop", - gossh.IEXTEN: "iexten", - gossh.ECHOCTL: "echoctl", - gossh.ECHOKE: "echoke", - gossh.PENDIN: "pendin", - gossh.OPOST: "opost", - gossh.OLCUC: "olcuc", - gossh.ONLCR: "onlcr", - gossh.OCRNL: "ocrnl", - gossh.ONOCR: "onocr", - gossh.ONLRET: "onlret", - gossh.CS7: "cs7", - gossh.CS8: "cs8", - gossh.PARENB: "parenb", - gossh.PARODD: "parodd", - gossh.TTY_OP_ISPEED: "tty_op_ispeed", - gossh.TTY_OP_OSPEED: "tty_op_ospeed", + ssh.VINTR: "intr", + ssh.VQUIT: "quit", + ssh.VERASE: "erase", + ssh.VKILL: "kill", + ssh.VEOF: "eof", + ssh.VEOL: "eol", + ssh.VEOL2: "eol2", + ssh.VSTART: "start", + ssh.VSTOP: "stop", + ssh.VSUSP: "susp", + ssh.VDSUSP: "dsusp", + ssh.VREPRINT: "rprnt", + ssh.VWERASE: "werase", + ssh.VLNEXT: "lnext", + ssh.VFLUSH: "flush", + ssh.VSWTCH: "swtch", + ssh.VSTATUS: "status", + ssh.VDISCARD: "discard", + ssh.IGNPAR: "ignpar", + ssh.PARMRK: "parmrk", + ssh.INPCK: "inpck", + ssh.ISTRIP: "istrip", + ssh.INLCR: "inlcr", + ssh.IGNCR: "igncr", + ssh.ICRNL: "icrnl", + ssh.IUCLC: "iuclc", + ssh.IXON: "ixon", + ssh.IXANY: "ixany", + ssh.IXOFF: "ixoff", + ssh.IMAXBEL: "imaxbel", + ssh.IUTF8: "iutf8", + ssh.ISIG: "isig", + ssh.ICANON: "icanon", + ssh.XCASE: "xcase", + ssh.ECHO: "echo", + ssh.ECHOE: "echoe", + ssh.ECHOK: "echok", + ssh.ECHONL: "echonl", + ssh.NOFLSH: "noflsh", + ssh.TOSTOP: "tostop", + ssh.IEXTEN: "iexten", + ssh.ECHOCTL: "echoctl", + ssh.ECHOKE: "echoke", + ssh.PENDIN: "pendin", + ssh.OPOST: "opost", + ssh.OLCUC: "olcuc", + ssh.ONLCR: "onlcr", + ssh.OCRNL: "ocrnl", + ssh.ONOCR: "onocr", + ssh.ONLRET: "onlret", + ssh.CS7: "cs7", + ssh.CS8: "cs8", + ssh.PARENB: "parenb", + ssh.PARODD: "parodd", + ssh.TTY_OP_ISPEED: "tty_op_ispeed", + ssh.TTY_OP_OSPEED: "tty_op_ospeed", } // startWithPTY starts cmd with a pseudo-terminal attached to Stdin, Stdout and Stderr. @@ -1012,11 +1011,11 @@ func (ss *sshSession) startWithPTY() (ptyFile, tty *os.File, err error) { tios.Col = int(ptyReq.Window.Width) for c, v := range ptyReq.Modes { - if c == gossh.TTY_OP_ISPEED { + if c == ssh.TTY_OP_ISPEED { tios.Ispeed = int(v) continue } - if c == gossh.TTY_OP_OSPEED { + if c == ssh.TTY_OP_OSPEED { tios.Ospeed = int(v) continue } diff --git a/ssh/tailssh/incubator_plan9.go b/ssh/tailssh/incubator_plan9.go index 69112635f5c11..8d0031413e4a4 100644 --- a/ssh/tailssh/incubator_plan9.go +++ b/ssh/tailssh/incubator_plan9.go @@ -92,7 +92,7 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err "--tty-name=", // updated in-place by startWithPTY } - nm := ss.conn.srv.lb.NetMap() + nm := ss.conn.srv.lb.NetMapNoPeers() forceV1Behavior := nm.HasCap(tailcfg.NodeAttrSSHBehaviorV1) && !nm.HasCap(tailcfg.NodeAttrSSHBehaviorV2) if forceV1Behavior { incubatorArgs = append(incubatorArgs, "--force-v1-behavior") diff --git a/ssh/tailssh/privs_test.go b/ssh/tailssh/privs_test.go index f0ec66c64e581..bd483e2b48fa2 100644 --- a/ssh/tailssh/privs_test.go +++ b/ssh/tailssh/privs_test.go @@ -20,6 +20,7 @@ import ( "syscall" "testing" + "tailscale.com/tstest" "tailscale.com/types/logger" ) @@ -71,9 +72,7 @@ func TestDoDropPrivileges(t *testing.T) { os.Exit(0) } - if os.Getuid() != 0 { - t.Skip("test only works when run as root") - } + tstest.RequireRoot(t) rerunSelf := func(t *testing.T, input SubprocInput) []byte { fpath := filepath.Join(t.TempDir(), "out.json") @@ -262,12 +261,10 @@ func maybeValidUID(id int) bool { return true } - var u1 user.UnknownUserIdError - if errors.As(err, &u1) { + if _, ok := errors.AsType[user.UnknownUserIdError](err); ok { return false } - var u2 user.UnknownUserError - if errors.As(err, &u2) { + if _, ok := errors.AsType[user.UnknownUserError](err); ok { return false } @@ -281,12 +278,10 @@ func maybeValidGID(id int) bool { return true } - var u1 user.UnknownGroupIdError - if errors.As(err, &u1) { + if _, ok := errors.AsType[user.UnknownGroupIdError](err); ok { return false } - var u2 user.UnknownGroupError - if errors.As(err, &u2) { + if _, ok := errors.AsType[user.UnknownGroupError](err); ok { return false } diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index cb56f701b5e68..e01f78eb3ae50 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -14,6 +14,7 @@ import ( "errors" "fmt" "io" + "maps" "net" "net/http" "net/netip" @@ -29,7 +30,8 @@ import ( "syscall" "time" - gossh "golang.org/x/crypto/ssh" + gliderssh "github.com/tailscale/gliderssh" + "golang.org/x/crypto/ssh" "tailscale.com/envknob" "tailscale.com/feature" "tailscale.com/ipn/ipnlocal" @@ -37,7 +39,6 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/sessionrecording" "tailscale.com/tailcfg" - "tailscale.com/tempfork/gliderlabs/ssh" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" @@ -53,10 +54,10 @@ var ( sshDisableForwarding = envknob.RegisterBool("TS_SSH_DISABLE_FORWARDING") sshDisablePTY = envknob.RegisterBool("TS_SSH_DISABLE_PTY") - // errTerminal is an empty gossh.PartialSuccessError (with no 'Next' + // errTerminal is an empty ssh.PartialSuccessError (with no 'Next' // authentication methods that may proceed), which results in the SSH // server immediately disconnecting the client. - errTerminal = &gossh.PartialSuccessError{} + errTerminal = &ssh.PartialSuccessError{} // hookSSHLoginSuccess is called after successful SSH authentication. // It is set by platform-specific code (e.g., auditd_linux.go). @@ -73,9 +74,9 @@ const ( // ipnLocalBackend is the subset of ipnlocal.LocalBackend that we use. // It is used for testing. type ipnLocalBackend interface { - GetSSH_HostKeys() ([]gossh.Signer, error) ShouldRunSSH() bool NetMap() *netmap.NetworkMap + NetMapNoPeers() *netmap.NetworkMap WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) DoNoiseRequest(req *http.Request) (*http.Response, error) Dialer() *tsdial.Dialer @@ -106,6 +107,8 @@ func (srv *server) now() time.Time { } func init() { + feature.HookGetSSHHostKeyPublicStrings.Set(getHostKeyPublicStrings) + ipnlocal.RegisterC2N("/ssh/usernames", handleC2NSSHUsernames) ipnlocal.RegisterNewSSHServer(func(logf logger.Logf, lb *ipnlocal.LocalBackend) (ipnlocal.SSHServer, error) { tsd, err := os.Executable() if err != nil { @@ -202,7 +205,7 @@ func (srv *server) OnPolicyChange() { } // conn represents a single SSH connection and its associated -// ssh.Server. +// gliderssh.Server. // // During the lifecycle of a connection, the following are called in order: // Setup and discover server info @@ -218,9 +221,9 @@ func (srv *server) OnPolicyChange() { // channels concurrently. At which point any of the following can be called // in any order. // - c.handleSessionPostSSHAuth -// - c.mayForwardLocalPortTo followed by ssh.DirectTCPIPHandler +// - c.mayForwardLocalPortTo followed by gliderssh.DirectTCPIPHandler type conn struct { - *ssh.Server + *gliderssh.Server srv *server insecureSkipTailscaleAuth bool // used by tests. @@ -232,9 +235,9 @@ type conn struct { idH string connID string // ID that's shared with control - // spac is a [gossh.ServerPreAuthConn] used for sending auth banners. + // spac is a [ssh.ServerPreAuthConn] used for sending auth banners. // Banners cannot be sent after auth completes. - spac gossh.ServerPreAuthConn + spac ssh.ServerPreAuthConn // The following fields are set during clientAuth and are used for policy // evaluation and session management. They are immutable after clientAuth @@ -278,7 +281,7 @@ func (c *conn) vlogf(format string, args ...any) { // errDenied is returned by auth callbacks when a connection is denied by the // policy. It writes the message to an auth banner and then returns an empty -// gossh.PartialSuccessError in order to stop processing authentication +// ssh.PartialSuccessError in order to stop processing authentication // attempts and immediately disconnect the client. func (c *conn) errDenied(message string) error { if message == "" { @@ -291,7 +294,7 @@ func (c *conn) errDenied(message string) error { } // errBanner writes the given message to an auth banner and then returns an -// empty gossh.PartialSuccessError in order to stop processing authentication +// empty ssh.PartialSuccessError in order to stop processing authentication // attempts and immediately disconnect the client. The contents of err is not // leaked in the auth banner, but it is logged to the server's log. func (c *conn) errBanner(message string, err error) error { @@ -306,7 +309,7 @@ func (c *conn) errBanner(message string, err error) error { // errUnexpected is returned by auth callbacks that encounter an unexpected // error, such as being unable to send an auth banner. It sends an empty -// gossh.PartialSuccessError to tell gossh.Server to stop processing +// ssh.PartialSuccessError to tell ssh.Server to stop processing // authentication attempts and instead disconnect immediately. func (c *conn) errUnexpected(err error) error { c.logf("terminal error: %s", err) @@ -317,11 +320,11 @@ func (c *conn) errUnexpected(err error) error { // // If policy evaluation fails, it returns an error. // If access is denied, it returns an error. This must always be an empty -// gossh.PartialSuccessError to prevent further authentication methods from +// ssh.PartialSuccessError to prevent further authentication methods from // being tried. -func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retErr error) { +func (c *conn) clientAuth(cm ssh.ConnMetadata) (perms *ssh.Permissions, retErr error) { defer func() { - if pse, ok := retErr.(*gossh.PartialSuccessError); ok { + if pse, ok := retErr.(*ssh.PartialSuccessError); ok { if pse.Next.GSSAPIWithMICConfig != nil || pse.Next.KeyboardInteractiveCallback != nil || pse.Next.PasswordCallback != nil || @@ -334,7 +337,7 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE }() if c.insecureSkipTailscaleAuth { - return &gossh.Permissions{}, nil + return &ssh.Permissions{}, nil } if err := c.setInfo(cm); err != nil { @@ -382,7 +385,7 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE } c.finalAction = action c.authCompleted.Store(true) - return &gossh.Permissions{}, nil + return &ssh.Permissions{}, nil case action.Reject: metricTerminalReject.Add(1) c.finalAction = action @@ -415,14 +418,14 @@ func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retE } } -// ServerConfig implements ssh.ServerConfigCallback. -func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig { - return &gossh.ServerConfig{ - PreAuthConnCallback: func(spac gossh.ServerPreAuthConn) { +// ServerConfig implements gliderssh.ServerConfigCallback. +func (c *conn) ServerConfig(ctx gliderssh.Context) *ssh.ServerConfig { + return &ssh.ServerConfig{ + PreAuthConnCallback: func(spac ssh.ServerPreAuthConn) { c.spac = spac }, NoClientAuth: true, // required for the NoClientAuthCallback to run - NoClientAuthCallback: func(cm gossh.ConnMetadata) (*gossh.Permissions, error) { + NoClientAuthCallback: func(cm ssh.ConnMetadata) (*ssh.Permissions, error) { // First perform client authentication, which can potentially // involve multiple steps (for example prompting user to log in to // Tailscale admin panel to confirm identity). @@ -436,10 +439,10 @@ func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig { // specify a username ending in "+password" to force password auth. // The actual value of the password doesn't matter. if strings.HasSuffix(cm.User(), forcePasswordSuffix) { - return nil, &gossh.PartialSuccessError{ - Next: gossh.ServerAuthCallbacks{ - PasswordCallback: func(_ gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { - return &gossh.Permissions{}, nil + return nil, &ssh.PartialSuccessError{ + Next: ssh.ServerAuthCallbacks{ + PasswordCallback: func(_ ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + return &ssh.Permissions{}, nil }, }, } @@ -447,14 +450,14 @@ func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig { return perms, nil }, - PasswordCallback: func(cm gossh.ConnMetadata, pword []byte) (*gossh.Permissions, error) { + PasswordCallback: func(cm ssh.ConnMetadata, pword []byte) (*ssh.Permissions, error) { // Some clients don't request 'none' authentication. Instead, they // immediately supply a password. We humor them by accepting the // password, but authenticate as usual, ignoring the actual value of // the password. return c.clientAuth(cm) }, - PublicKeyCallback: func(cm gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { + PublicKeyCallback: func(cm ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { // Some clients don't request 'none' authentication. Instead, they // immediately supply a public key. We humor them by accepting the // key, but authenticate as usual, ignoring the actual content of @@ -477,39 +480,38 @@ func (srv *server) newConn() (*conn, error) { c := &conn{srv: srv} now := srv.now() c.connID = fmt.Sprintf("ssh-conn-%s-%02x", now.UTC().Format("20060102T150405"), randBytes(5)) - fwdHandler := &ssh.ForwardedTCPHandler{} - c.Server = &ssh.Server{ + fwdHandler := &gliderssh.ForwardedTCPHandler{} + streamLocalFwdHandler := &gliderssh.ForwardedUnixHandler{} + c.Server = &gliderssh.Server{ Version: "Tailscale", ServerConfigCallback: c.ServerConfig, Handler: c.handleSessionPostSSHAuth, LocalPortForwardingCallback: c.mayForwardLocalPortTo, ReversePortForwardingCallback: c.mayReversePortForwardTo, - SubsystemHandlers: map[string]ssh.SubsystemHandler{ + + LocalUnixForwardingCallback: c.mayForwardLocalUnixTo, + ReverseUnixForwardingCallback: c.mayReverseUnixForwardTo, + + SubsystemHandlers: map[string]gliderssh.SubsystemHandler{ "sftp": c.handleSessionPostSSHAuth, }, - // Note: the direct-tcpip channel handler and LocalPortForwardingCallback - // only adds support for forwarding ports from the local machine. - // TODO(maisem/bradfitz): add remote port forwarding support. - ChannelHandlers: map[string]ssh.ChannelHandler{ - "direct-tcpip": ssh.DirectTCPIPHandler, + ChannelHandlers: map[string]gliderssh.ChannelHandler{ + "direct-tcpip": gliderssh.DirectTCPIPHandler, + "direct-streamlocal@openssh.com": gliderssh.DirectStreamLocalHandler, }, - RequestHandlers: map[string]ssh.RequestHandler{ - "tcpip-forward": fwdHandler.HandleSSHRequest, - "cancel-tcpip-forward": fwdHandler.HandleSSHRequest, + RequestHandlers: map[string]gliderssh.RequestHandler{ + "tcpip-forward": fwdHandler.HandleSSHRequest, + "cancel-tcpip-forward": fwdHandler.HandleSSHRequest, + "streamlocal-forward@openssh.com": streamLocalFwdHandler.HandleSSHRequest, + "cancel-streamlocal-forward@openssh.com": streamLocalFwdHandler.HandleSSHRequest, }, } ss := c.Server - for k, v := range ssh.DefaultRequestHandlers { - ss.RequestHandlers[k] = v - } - for k, v := range ssh.DefaultChannelHandlers { - ss.ChannelHandlers[k] = v - } - for k, v := range ssh.DefaultSubsystemHandlers { - ss.SubsystemHandlers[k] = v - } - keys, err := srv.lb.GetSSH_HostKeys() + maps.Copy(ss.RequestHandlers, gliderssh.DefaultRequestHandlers) + maps.Copy(ss.ChannelHandlers, gliderssh.DefaultChannelHandlers) + maps.Copy(ss.SubsystemHandlers, gliderssh.DefaultSubsystemHandlers) + keys, err := getHostKeys(srv.lb.TailscaleVarRoot(), srv.logf) if err != nil { return nil, err } @@ -522,7 +524,7 @@ func (srv *server) newConn() (*conn, error) { // mayReversePortPortForwardTo reports whether the ctx should be allowed to port forward // to the specified host and port. // TODO(bradfitz/maisem): should we have more checks on host/port? -func (c *conn) mayReversePortForwardTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { +func (c *conn) mayReversePortForwardTo(ctx gliderssh.Context, destinationHost string, destinationPort uint32) bool { if sshDisableForwarding() { return false } @@ -536,7 +538,7 @@ func (c *conn) mayReversePortForwardTo(ctx ssh.Context, destinationHost string, // mayForwardLocalPortTo reports whether the ctx should be allowed to port forward // to the specified host and port. // TODO(bradfitz/maisem): should we have more checks on host/port? -func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool { +func (c *conn) mayForwardLocalPortTo(ctx gliderssh.Context, destinationHost string, destinationPort uint32) bool { if sshDisableForwarding() { return false } @@ -547,6 +549,48 @@ func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, de return false } +// mayForwardLocalUnixTo is the server-side handler for +// direct-streamlocal@openssh.com (SSH -L with Unix sockets). It returns a +// connection to the specified Unix domain socket path if forwarding is +// permitted, or an error if not. +func (c *conn) mayForwardLocalUnixTo(ctx gliderssh.Context, socketPath string) (net.Conn, error) { + if sshDisableForwarding() { + return nil, gliderssh.ErrRejected + } + if c.finalAction != nil && c.finalAction.AllowLocalPortForwarding { + metricLocalPortForward.Add(1) + cb := gliderssh.NewLocalUnixForwardingCallback(c.unixForwardingOptions()) + return cb(ctx, socketPath) + } + return nil, gliderssh.ErrRejected +} + +// mayReverseUnixForwardTo is the server-side handler for +// streamlocal-forward@openssh.com (SSH -R with Unix sockets). It returns a +// listener for the specified Unix domain socket path if reverse forwarding is +// permitted, or an error if not. +func (c *conn) mayReverseUnixForwardTo(ctx gliderssh.Context, socketPath string) (net.Listener, error) { + if sshDisableForwarding() { + return nil, gliderssh.ErrRejected + } + if c.finalAction != nil && c.finalAction.AllowRemotePortForwarding { + metricRemotePortForward.Add(1) + cb := gliderssh.NewReverseUnixForwardingCallback(c.unixForwardingOptions()) + return cb(ctx, socketPath) + } + return nil, gliderssh.ErrRejected +} + +// unixForwardingOptions returns the Unix forwarding options scoped to the +// authenticated local user. Socket paths are restricted to the user's home +// directory, /tmp, and /run/user/. +func (c *conn) unixForwardingOptions() gliderssh.UnixForwardingOptions { + return gliderssh.UnixForwardingOptions{ + AllowedDirectories: gliderssh.UserSocketDirectories(c.localUser.HomeDir, c.localUser.Uid), + BindUnlink: true, + } +} + // sshPolicy returns the SSHPolicy for current node. // If there is no SSHPolicy in the netmap, it returns a debugPolicy // if one is defined. @@ -555,7 +599,7 @@ func (c *conn) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) { if !lb.ShouldRunSSH() { return nil, false } - nm := lb.NetMap() + nm := lb.NetMapNoPeers() if nm == nil { return nil, false } @@ -594,7 +638,7 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) { // connInfo populates the sshConnInfo from the provided arguments, // validating only that they represent a known Tailscale identity. -func (c *conn) setInfo(cm gossh.ConnMetadata) error { +func (c *conn) setInfo(cm ssh.ConnMetadata) error { if c.info != nil { return nil } @@ -644,7 +688,7 @@ func (c *conn) evaluatePolicy() (_ *tailcfg.SSHAction, localUser string, acceptE // handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication, // but not necessarily before all the Tailscale-level extra verification has // completed. It also handles SFTP requests. -func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { +func (c *conn) handleSessionPostSSHAuth(s gliderssh.Session) { // Do this check after auth, but before starting the session. switch s.Subsystem() { case "sftp": @@ -674,7 +718,7 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { } func (c *conn) expandDelegateURLLocked(actionURL string) string { - nm := c.srv.lb.NetMap() + nm := c.srv.lb.NetMapNoPeers() ci := c.info lu := c.localUser var dstNodeID string @@ -693,7 +737,7 @@ func (c *conn) expandDelegateURLLocked(actionURL string) string { // sshSession is an accepted Tailscale SSH session. type sshSession struct { - ssh.Session + gliderssh.Session sharedID string // ID that's shared with control logf logger.Logf @@ -706,8 +750,8 @@ type sshSession struct { cmd *exec.Cmd wrStdin io.WriteCloser rdStdout io.ReadCloser - rdStderr io.ReadCloser // rdStderr is nil for pty sessions - ptyReq *ssh.Pty // non-nil for pty sessions + rdStderr io.ReadCloser // rdStderr is nil for pty sessions + ptyReq *gliderssh.Pty // non-nil for pty sessions // childPipes is a list of pipes that need to be closed when the process exits. // For pty sessions, this is the tty fd. @@ -717,6 +761,12 @@ type sshSession struct { // We use this sync.Once to ensure that we only terminate the process once, // either it exits itself or is terminated exitOnce sync.Once + + // exitHandled is closed when killProcessOnContextDone finishes writing any + // termination message to the client. run() waits on this before calling + // ss.Exit to ensure the message is flushed before the SSH channel is torn + // down. It is initialized by run() before starting killProcessOnContextDone. + exitHandled chan struct{} } func (ss *sshSession) vlogf(format string, args ...any) { @@ -725,7 +775,7 @@ func (ss *sshSession) vlogf(format string, args ...any) { } } -func (c *conn) newSSHSession(s ssh.Session) *sshSession { +func (c *conn) newSSHSession(s gliderssh.Session) *sshSession { sharedID := fmt.Sprintf("sess-%s-%02x", c.srv.now().UTC().Format("20060102T150405"), randBytes(5)) c.logf("starting session: %v", sharedID) ctx, cancel := context.WithCancelCause(s.Context()) @@ -812,6 +862,7 @@ func (c *conn) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSHActi // killProcessOnContextDone waits for ss.ctx to be done and kills the process, // unless the process has already exited. func (ss *sshSession) killProcessOnContextDone() { + defer close(ss.exitHandled) <-ss.ctx.Done() // Either the process has already exited, in which case this does nothing. // Or, the process is still running in which case this will kill it. @@ -859,10 +910,10 @@ func (c *conn) detachSession(ss *sshSession) { var errSessionDone = errors.New("session is done") // handleSSHAgentForwarding starts a Unix socket listener and in the background -// forwards agent connections between the listener and the ssh.Session. +// forwards agent connections between the listener and the gliderssh.Session. // On success, it assigns ss.agentListener. -func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *userMeta) error { - if !ssh.AgentRequested(ss) || !ss.conn.finalAction.AllowAgentForwarding { +func (ss *sshSession) handleSSHAgentForwarding(s gliderssh.Session, lu *userMeta) error { + if !gliderssh.AgentRequested(ss) || !ss.conn.finalAction.AllowAgentForwarding { return nil } if sshDisableForwarding() { @@ -872,7 +923,7 @@ func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *userMeta) erro return nil } ss.logf("ssh: agent forwarding requested") - ln, err := ssh.NewAgentListener() + ln, err := gliderssh.NewAgentListener() if err != nil { return err } @@ -904,7 +955,7 @@ func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *userMeta) erro return err } - go ssh.ForwardAgentConnections(ln, s) + go gliderssh.ForwardAgentConnections(ln, s) ss.agentListener = ln return nil } @@ -964,8 +1015,7 @@ func (ss *sshSession) run() { var err error rec, err = ss.startNewRecording() if err != nil { - var uve userVisibleError - if errors.As(err, &uve) { + if uve, ok := errors.AsType[userVisibleError](err); ok { fmt.Fprintf(ss, "%s\r\n", uve.SSHTerminationMessage()) } else { fmt.Fprintf(ss, "can't start new recording\r\n") @@ -985,15 +1035,17 @@ func (ss *sshSession) run() { if err != nil { logf("start failed: %v", err.Error()) if errors.Is(err, context.Canceled) { - err := context.Cause(ss.ctx) - var uve userVisibleError - if errors.As(err, &uve) { - fmt.Fprintf(ss, "%s\r\n", uve) + cause := context.Cause(ss.ctx) + if serr, ok := cause.(SSHTerminationError); ok { + if msg := serr.SSHTerminationMessage(); msg != "" { + io.WriteString(ss.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n") + } } } ss.Exit(1) return } + ss.exitHandled = make(chan struct{}) go ss.killProcessOnContextDone() var processDone atomic.Bool @@ -1044,6 +1096,15 @@ func (ss *sshSession) run() { err = ss.cmd.Wait() processDone.Store(true) + if ss.ctx.Err() != nil { + // Context was canceled (e.g., recording upload failure). + // Wait for killProcessOnContextDone to finish writing any + // termination message before we proceed. This must happen + // before closeAll and CloseWrite so the SSH channel is + // still writable. + <-ss.exitHandled + } + // This will either make the SSH Termination goroutine be a no-op, // or itself will be a no-op because the process was killed by the // aforementioned goroutine. @@ -1056,6 +1117,7 @@ func (ss *sshSession) run() { select { case <-outputDone: case <-ss.ctx.Done(): + <-ss.exitHandled } if err == nil { @@ -1274,7 +1336,7 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { } } - var w ssh.Window + var w gliderssh.Window if ptyReq, _, isPtyReq := ss.Pty(); isPtyReq { w = ptyReq.Window } diff --git a/ssh/tailssh/tailssh_integration_test.go b/ssh/tailssh/tailssh_integration_test.go index 1135bebbc2a5b..7b70a6d512b3b 100644 --- a/ssh/tailssh/tailssh_integration_test.go +++ b/ssh/tailssh/tailssh_integration_test.go @@ -6,7 +6,6 @@ package tailssh import ( - "bufio" "bytes" "context" "crypto/rand" @@ -31,12 +30,11 @@ import ( "github.com/bramvdbogaerde/go-scp" "github.com/google/go-cmp/cmp" "github.com/pkg/sftp" + gliderssh "github.com/tailscale/gliderssh" "golang.org/x/crypto/ssh" - gossh "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" - glider "tailscale.com/tempfork/gliderlabs/ssh" "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/util/set" @@ -53,50 +51,41 @@ import ( // - User "testuser" exists // - "testuser" is in groups "groupone" and "grouptwo" +// testVarRoot is a temp directory used as the TailscaleVarRoot for +// host key generation during integration tests. The test containers +// don't have system host keys (/etc/ssh/ssh_host_*_key) since they +// only install openssh-client, so getHostKeys needs a valid var root +// to generate keys into. +var testVarRoot string + func TestMain(m *testing.M) { + debugTest.Store(true) + // Create our log file. - file, err := os.OpenFile("/tmp/tailscalessh.log", os.O_CREATE|os.O_WRONLY, 0666) - if err != nil { + if err := os.WriteFile("/tmp/tailscalessh.log", nil, 0666); err != nil { log.Fatal(err) } - file.Close() - // Tail our log file. - cmd := exec.Command("tail", "-F", "/tmp/tailscalessh.log") - - r, err := cmd.StdoutPipe() + // Create a temp directory for SSH host keys. + var err error + testVarRoot, err = os.MkdirTemp("", "tailssh-test-var") if err != nil { - return + log.Fatal(err) } - scanner := bufio.NewScanner(r) - go func() { - for scanner.Scan() { - line := scanner.Text() - log.Println(line) - } - }() + code := m.Run() - err = cmd.Start() - if err != nil { - return + os.RemoveAll(testVarRoot) + + // Print any log output from the incubator subprocesses. + if b, err := os.ReadFile("/tmp/tailscalessh.log"); err == nil && len(b) > 0 { + log.Print(string(b)) } - defer func() { - // tail -f has a default sleep interval of 1 second, so it takes a - // moment for it to finish reading our log file after we've terminated. - // So, wait a bit to let it catch up. - time.Sleep(2 * time.Second) - }() - m.Run() + os.Exit(code) } func TestIntegrationSSH(t *testing.T) { - debugTest.Store(true) - t.Cleanup(func() { - debugTest.Store(false) - }) - homeDir := "/home/testuser" if runtime.GOOS == "darwin" { homeDir = "/Users/testuser" @@ -202,11 +191,6 @@ func TestIntegrationSSH(t *testing.T) { } func TestIntegrationSFTP(t *testing.T) { - debugTest.Store(true) - t.Cleanup(func() { - debugTest.Store(false) - }) - for _, forceV1Behavior := range []bool{false, true} { name := "v2" if forceV1Behavior { @@ -263,11 +247,6 @@ func TestIntegrationSFTP(t *testing.T) { } func TestIntegrationSCP(t *testing.T) { - debugTest.Store(true) - t.Cleanup(func() { - debugTest.Store(false) - }) - for _, forceV1Behavior := range []bool{false, true} { name := "v2" if forceV1Behavior { @@ -321,11 +300,6 @@ func TestIntegrationSCP(t *testing.T) { } func TestSSHAgentForwarding(t *testing.T) { - debugTest.Store(true) - t.Cleanup(func() { - debugTest.Store(false) - }) - // Create a client SSH key tmpDir, err := os.MkdirTemp("", "") if err != nil { @@ -347,11 +321,11 @@ func TestSSHAgentForwarding(t *testing.T) { }) // Run an SSH server that accepts connections from that client SSH key. - gs := glider.Server{ - Handler: func(s glider.Session) { + gs := gliderssh.Server{ + Handler: func(s gliderssh.Session) { io.WriteString(s, "Hello world\n") }, - PublicKeyHandler: func(ctx glider.Context, key glider.PublicKey) error { + PublicKeyHandler: func(ctx gliderssh.Context, key gliderssh.PublicKey) error { // Note - this is not meant to be cryptographically secure, it's // just checking that SSH agent forwarding is forwarding the right // key. @@ -415,11 +389,6 @@ func TestSSHAgentForwarding(t *testing.T) { // request 'none' auth and instead immediately authenticate with a public key // or password. func TestIntegrationParamiko(t *testing.T) { - debugTest.Store(true) - t.Cleanup(func() { - debugTest.Store(false) - }) - addr := testServer(t, "testuser", true, false) host, port, err := net.SplitHostPort(addr) if err != nil { @@ -451,6 +420,233 @@ client.exec_command('pwd') } } +// TestLocalUnixForwarding tests direct-streamlocal@openssh.com, which is what +// podman remote (issue #12409) and VSCode Remote (issue #5295) use to reach +// Unix domain sockets on the remote host through SSH. The client opens a +// channel to a Unix socket path on the server, and data is proxied through. +func TestLocalUnixForwarding(t *testing.T) { + debugTest.Store(true) + t.Cleanup(func() { + debugTest.Store(false) + }) + + // Create a Unix socket server in /tmp that simulates a service like + // podman's API socket at /run/user//podman/podman.sock. + socketDir, err := os.MkdirTemp("", "tailssh-test-") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { os.RemoveAll(socketDir) }) + socketPath := filepath.Join(socketDir, "test-service.sock") + + ul, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { ul.Close() }) + + // The service echoes back whatever it receives, like an API server would. + go func() { + for { + conn, err := ul.Accept() + if err != nil { + return + } + go func() { + defer conn.Close() + io.Copy(conn, conn) + }() + } + }() + + // Start Tailscale SSH server with local port forwarding enabled. + addr := testServerWithOpts(t, testServerOpts{ + username: "testuser", + allowLocalPortForwarding: true, + }) + + // Connect to the Tailscale SSH server. + cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { cl.Close() }) + + // Open a direct-streamlocal@openssh.com channel to the Unix socket, + // exactly as podman remote does. + conn, err := cl.Dial("unix", socketPath) + if err != nil { + t.Fatalf("failed to dial unix socket through SSH: %s", err) + } + defer conn.Close() + + // Send data through the tunnel and verify it echoes back. + want := "GET /_ping HTTP/1.1\r\nHost: d\r\n\r\n" + _, err = io.WriteString(conn, want) + if err != nil { + t.Fatalf("failed to write through tunnel: %s", err) + } + + got := make([]byte, len(want)) + _, err = io.ReadFull(conn, got) + if err != nil { + t.Fatalf("failed to read through tunnel: %s", err) + } + if string(got) != want { + t.Errorf("got %q, want %q", got, want) + } +} + +// TestReverseUnixForwarding tests streamlocal-forward@openssh.com, which tools +// like VSCode Remote and Zed use to create Unix domain sockets on the remote +// host that forward connections back to the client through SSH. +func TestReverseUnixForwarding(t *testing.T) { + debugTest.Store(true) + t.Cleanup(func() { + debugTest.Store(false) + }) + + // Start Tailscale SSH server with remote port forwarding enabled. + addr := testServerWithOpts(t, testServerOpts{ + username: "testuser", + allowRemotePortForwarding: true, + }) + + // Connect to the Tailscale SSH server. + cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { cl.Close() }) + + // Request reverse forwarding -- the server creates a Unix socket and + // forwards incoming connections back through the SSH tunnel. + socketDir, err := os.MkdirTemp("", "tailssh-test-") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { os.RemoveAll(socketDir) }) + remoteSocketPath := filepath.Join(socketDir, "reverse.sock") + + ln, err := cl.ListenUnix(remoteSocketPath) + if err != nil { + t.Fatalf("failed to request reverse unix forwarding: %s", err) + } + t.Cleanup(func() { ln.Close() }) + + // Verify the socket file was created on the server side. + if _, err := os.Stat(remoteSocketPath); err != nil { + t.Fatalf("reverse forwarded socket not created: %s", err) + } + + // Accept a connection from the tunnel (client side) and write data. + want := "hello from reverse tunnel" + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + io.WriteString(conn, want) + }() + + // Connect directly to the socket on the server side, simulating a + // local process connecting to the VSCode/Zed IPC socket. + conn, err := net.Dial("unix", remoteSocketPath) + if err != nil { + t.Fatalf("failed to connect to reverse forwarded socket: %s", err) + } + defer conn.Close() + + got, err := io.ReadAll(conn) + if err != nil { + t.Fatalf("failed to read from reverse forwarded socket: %s", err) + } + if string(got) != want { + t.Errorf("got %q, want %q", got, want) + } +} + +// TestUnixForwardingDenied verifies that Unix socket forwarding is rejected +// when the SSH policy does not permit port forwarding. +func TestUnixForwardingDenied(t *testing.T) { + debugTest.Store(true) + t.Cleanup(func() { + debugTest.Store(false) + }) + + // Start server with forwarding disabled (the default policy). + addr := testServerWithOpts(t, testServerOpts{ + username: "testuser", + }) + + cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { cl.Close() }) + + // Direct Unix socket forwarding should be rejected. + _, err = cl.Dial("unix", "/tmp/anything.sock") + if err == nil { + t.Error("expected direct unix forwarding to be rejected, but it succeeded") + } + + // Reverse Unix socket forwarding should also be rejected. + socketDir, err := os.MkdirTemp("", "tailssh-test-") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { os.RemoveAll(socketDir) }) + + _, err = cl.ListenUnix(filepath.Join(socketDir, "denied.sock")) + if err == nil { + t.Error("expected reverse unix forwarding to be rejected, but it succeeded") + } +} + +// TestUnixForwardingPathRestriction verifies that socket paths outside the +// allowed directories (home, /tmp, /run/user/) are rejected even when +// forwarding is permitted by policy. +func TestUnixForwardingPathRestriction(t *testing.T) { + debugTest.Store(true) + t.Cleanup(func() { + debugTest.Store(false) + }) + + addr := testServerWithOpts(t, testServerOpts{ + username: "testuser", + allowLocalPortForwarding: true, + allowRemotePortForwarding: true, + }) + + cl, err := ssh.Dial("tcp", addr, &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { cl.Close() }) + + // Paths outside allowed directories should be rejected. + restrictedPaths := []string{ + "/var/run/docker.sock", + "/etc/evil.sock", + } + for _, path := range restrictedPaths { + _, err := cl.Dial("unix", path) + if err == nil { + t.Errorf("expected direct forwarding to %q to be rejected, but it succeeded", path) + } + } +} + func fallbackToSUAvailable() bool { if runtime.GOOS != "linux" { return false @@ -496,26 +692,34 @@ func (s *session) run(t *testing.T, cmdString string, shell bool) string { func (s *session) read() string { ch := make(chan []byte) go func() { + defer close(ch) for { b := make([]byte, 1) n, err := s.stdout.Read(b) if n > 0 { ch <- b } - if err == io.EOF { + if err != nil { return } } }() // Read first byte in blocking fashion. - _got := <-ch + b, ok := <-ch + if !ok { + return "" + } + _got := b - // Read subsequent bytes in non-blocking fashion. + // Read subsequent bytes until EOF or silence. readLoop: for { select { - case b := <-ch: + case b, ok := <-ch: + if !ok { + break readLoop + } _got = append(_got, b...) case <-time.After(1 * time.Second): break readLoop @@ -569,6 +773,47 @@ func testServer(t *testing.T, username string, forceV1Behavior bool, allowSendEn return l.Addr().String() } +type testServerOpts struct { + username string + forceV1Behavior bool + allowSendEnv bool + allowLocalPortForwarding bool + allowRemotePortForwarding bool +} + +func testServerWithOpts(t *testing.T, opts testServerOpts) string { + t.Helper() + srv := &server{ + lb: &testBackend{ + localUser: opts.username, + forceV1Behavior: opts.forceV1Behavior, + allowSendEnv: opts.allowSendEnv, + allowLocalPortForwarding: opts.allowLocalPortForwarding, + allowRemotePortForwarding: opts.allowRemotePortForwarding, + }, + logf: log.Printf, + tailscaledPath: os.Getenv("TAILSCALED_PATH"), + timeNow: time.Now, + } + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { l.Close() }) + + go func() { + for { + conn, err := l.Accept() + if err == nil { + go srv.HandleSSHConn(&addressFakingConn{conn}) + } + } + }() + + return l.Addr().String() +} + func testSession(t *testing.T, forceV1Behavior bool, allowSendEnv bool, sendEnv map[string]string) *session { cl := testClient(t, forceV1Behavior, allowSendEnv) return testSessionFor(t, cl, sendEnv) @@ -626,31 +871,11 @@ func generateClientKey(t *testing.T, privateKeyFile string) (ssh.Signer, *rsa.Pr // testBackend implements ipnLocalBackend type testBackend struct { - localUser string - forceV1Behavior bool - allowSendEnv bool -} - -func (tb *testBackend) GetSSH_HostKeys() ([]gossh.Signer, error) { - var result []gossh.Signer - var priv any - var err error - const keySize = 2048 - priv, err = rsa.GenerateKey(rand.Reader, keySize) - if err != nil { - return nil, err - } - mk, err := x509.MarshalPKCS8PrivateKey(priv) - if err != nil { - return nil, err - } - hostKey := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: mk}) - signer, err := gossh.ParsePrivateKey(hostKey) - if err != nil { - return nil, err - } - result = append(result, signer) - return result, nil + localUser string + forceV1Behavior bool + allowSendEnv bool + allowLocalPortForwarding bool + allowRemotePortForwarding bool } func (tb *testBackend) ShouldRunSSH() bool { @@ -670,9 +895,14 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap { Rules: []*tailcfg.SSHRule{ { Principals: []*tailcfg.SSHPrincipal{{Any: true}}, - Action: &tailcfg.SSHAction{Accept: true, AllowAgentForwarding: true}, - SSHUsers: map[string]string{"*": tb.localUser}, - AcceptEnv: []string{"GIT_*", "EXACT_MATCH", "TEST?NG"}, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: tb.allowLocalPortForwarding, + AllowRemotePortForwarding: tb.allowRemotePortForwarding, + }, + SSHUsers: map[string]string{"*": tb.localUser}, + AcceptEnv: []string{"GIT_*", "EXACT_MATCH", "TEST?NG"}, }, }, }, @@ -680,6 +910,8 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap { } } +func (tb *testBackend) NetMapNoPeers() *netmap.NetworkMap { return tb.NetMap() } + func (tb *testBackend) WhoIs(_ string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { return (&tailcfg.Node{}).View(), tailcfg.UserProfile{ LoginName: tb.localUser + "@example.com", @@ -695,7 +927,7 @@ func (tb *testBackend) Dialer() *tsdial.Dialer { } func (tb *testBackend) TailscaleVarRoot() string { - return "" + return testVarRoot } func (tb *testBackend) NodeKey() key.NodePublic { diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 6d9d859a22d91..04c9cd2f51d9e 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -9,7 +9,6 @@ import ( "bytes" "context" "crypto/ecdsa" - "crypto/ed25519" "crypto/elliptic" "crypto/rand" "encoding/json" @@ -34,7 +33,7 @@ import ( "testing/synctest" "time" - gossh "golang.org/x/crypto/ssh" + gliderssh "github.com/tailscale/gliderssh" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" "tailscale.com/cmd/testwrapper/flakytest" @@ -44,14 +43,12 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/sessionrecording" "tailscale.com/tailcfg" - "tailscale.com/tempfork/gliderlabs/ssh" testssh "tailscale.com/tempfork/sshtest/ssh" "tailscale.com/tsd" "tailscale.com/tstest" "tailscale.com/types/key" "tailscale.com/types/logid" "tailscale.com/types/netmap" - "tailscale.com/types/ptr" "tailscale.com/util/cibuild" "tailscale.com/util/lineiter" "tailscale.com/util/must" @@ -96,7 +93,7 @@ func TestMatchRule(t *testing.T) { name: "expired", rule: &tailcfg.SSHRule{ Action: someAction, - RuleExpires: ptr.To(time.Unix(100, 0)), + RuleExpires: new(time.Unix(100, 0)), }, ci: &sshConnInfo{}, wantErr: errRuleExpired, @@ -382,6 +379,7 @@ func TestEvalSSHPolicy(t *testing.T) { type localState struct { sshEnabled bool matchingRule *tailcfg.SSHRule + varRoot string // if empty, TailscaleVarRoot returns "" // serverActions is a map of the action name to the action. // It is served for paths like https://unused/ssh-action/. @@ -389,31 +387,19 @@ type localState struct { serverActions map[string]*tailcfg.SSHAction } -var ( - currentUser = os.Getenv("USER") // Use the current user for the test. - testSigner gossh.Signer - testSignerOnce sync.Once -) +var currentUser = func() string { + // Prefer user.Current because the USER env var is not set in + // some environments (e.g. the golang:latest container used by CI). + if u, err := user.Current(); err == nil { + return u.Username + } + return os.Getenv("USER") +}() func (ts *localState) Dialer() *tsdial.Dialer { return &tsdial.Dialer{} } -func (ts *localState) GetSSH_HostKeys() ([]gossh.Signer, error) { - testSignerOnce.Do(func() { - _, priv, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - panic(err) - } - s, err := gossh.NewSignerFromSigner(priv) - if err != nil { - panic(err) - } - testSigner = s - }) - return []gossh.Signer{testSigner}, nil -} - func (ts *localState) ShouldRunSSH() bool { return ts.sshEnabled } @@ -436,6 +422,8 @@ func (ts *localState) NetMap() *netmap.NetworkMap { } } +func (ts *localState) NetMapNoPeers() *netmap.NetworkMap { return ts.NetMap() } + func (ts *localState) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { if proto != "tcp" { return tailcfg.NodeView{}, tailcfg.UserProfile{}, false @@ -469,7 +457,7 @@ func (ts *localState) DoNoiseRequest(req *http.Request) (*http.Response, error) } func (ts *localState) TailscaleVarRoot() string { - return "" + return ts.varRoot } func (ts *localState) NodeKey() key.NodePublic { @@ -491,14 +479,12 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule { } func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7707") - + if runtime.GOOS == "darwin" { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7707") + } if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) } - if runtime.GOOS == "darwin" && cibuild.On() { - t.Skipf("this fails on CI on macOS; see https://github.com/tailscale/tailscale/issues/7707") - } var handler http.HandlerFunc recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { @@ -509,6 +495,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { logf: tstest.WhileTestRunningLogger(t), lb: &localState{ sshEnabled: true, + varRoot: t.TempDir(), matchingRule: newSSHRule( &tailcfg.SSHAction{ Accept: true, @@ -572,9 +559,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { tstest.Replace(t, &handler, tt.handler) sc, dc := memnet.NewTCPConn(src, dst, 1024) var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) if err != nil { t.Errorf("client: %v", err) @@ -604,7 +589,7 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { t.Errorf("client output must not contain %q", x) } } - }() + }) if err := s.HandleSSHConn(dc); err != nil { t.Errorf("unexpected error: %v", err) } @@ -639,6 +624,7 @@ func TestMultipleRecorders(t *testing.T) { logf: tstest.WhileTestRunningLogger(t), lb: &localState{ sshEnabled: true, + varRoot: t.TempDir(), matchingRule: newSSHRule( &tailcfg.SSHAction{ Accept: true, @@ -667,9 +653,7 @@ func TestMultipleRecorders(t *testing.T) { } var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) if err != nil { t.Errorf("client: %v", err) @@ -691,7 +675,7 @@ func TestMultipleRecorders(t *testing.T) { if string(out) != "Ran echo!\n" { t.Errorf("client: unexpected output: %q", out) } - }() + }) if err := s.HandleSSHConn(dc); err != nil { t.Errorf("unexpected error: %v", err) } @@ -714,9 +698,9 @@ func TestSSHRecordingNonInteractive(t *testing.T) { t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) } var recording []byte - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + done := make(chan struct{}) recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { - defer cancel() + defer close(done) w.WriteHeader(http.StatusOK) w.(http.Flusher).Flush() @@ -732,6 +716,7 @@ func TestSSHRecordingNonInteractive(t *testing.T) { logf: tstest.WhileTestRunningLogger(t), lb: &localState{ sshEnabled: true, + varRoot: t.TempDir(), matchingRule: newSSHRule( &tailcfg.SSHAction{ Accept: true, @@ -758,9 +743,7 @@ func TestSSHRecordingNonInteractive(t *testing.T) { } var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) if err != nil { t.Errorf("client: %v", err) @@ -779,13 +762,17 @@ func TestSSHRecordingNonInteractive(t *testing.T) { if err != nil { t.Errorf("client: %v", err) } - }() + }) if err := s.HandleSSHConn(dc); err != nil { t.Errorf("unexpected error: %v", err) } wg.Wait() - <-ctx.Done() // wait for recording to finish + select { + case <-done: + case <-time.After(30 * time.Second): + t.Fatal("timed out waiting for recording") + } var ch sessionrecording.CastHeader if err := json.NewDecoder(bytes.NewReader(recording)).Decode(&ch); err != nil { t.Fatal(err) @@ -802,6 +789,7 @@ func TestSSHAuthFlow(t *testing.T) { if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) } + varRoot := t.TempDir() acceptRule := newSSHRule(&tailcfg.SSHAction{ Accept: true, Message: "Welcome to Tailscale SSH!", @@ -828,6 +816,7 @@ func TestSSHAuthFlow(t *testing.T) { name: "no-policy", state: &localState{ sshEnabled: true, + varRoot: varRoot, }, authErr: true, wantBanners: []string{"tailscale: tailnet policy does not permit you to SSH to this node\n"}, @@ -836,6 +825,7 @@ func TestSSHAuthFlow(t *testing.T) { name: "user-mismatch", state: &localState{ sshEnabled: true, + varRoot: varRoot, matchingRule: bobRule, }, authErr: true, @@ -845,6 +835,7 @@ func TestSSHAuthFlow(t *testing.T) { name: "accept", state: &localState{ sshEnabled: true, + varRoot: varRoot, matchingRule: acceptRule, }, wantBanners: []string{"Welcome to Tailscale SSH!"}, @@ -853,6 +844,7 @@ func TestSSHAuthFlow(t *testing.T) { name: "reject", state: &localState{ sshEnabled: true, + varRoot: varRoot, matchingRule: rejectRule, }, wantBanners: []string{"Go Away!"}, @@ -862,6 +854,7 @@ func TestSSHAuthFlow(t *testing.T) { name: "simple-check", state: &localState{ sshEnabled: true, + varRoot: varRoot, matchingRule: newSSHRule(&tailcfg.SSHAction{ HoldAndDelegate: "https://unused/ssh-action/accept", }), @@ -875,6 +868,7 @@ func TestSSHAuthFlow(t *testing.T) { name: "multi-check", state: &localState{ sshEnabled: true, + varRoot: varRoot, matchingRule: newSSHRule(&tailcfg.SSHAction{ Message: "First", HoldAndDelegate: "https://unused/ssh-action/check1", @@ -893,6 +887,7 @@ func TestSSHAuthFlow(t *testing.T) { name: "check-reject", state: &localState{ sshEnabled: true, + varRoot: varRoot, matchingRule: newSSHRule(&tailcfg.SSHAction{ Message: "First", HoldAndDelegate: "https://unused/ssh-action/reject", @@ -909,6 +904,7 @@ func TestSSHAuthFlow(t *testing.T) { sshUser: "alice+password", state: &localState{ sshEnabled: true, + varRoot: varRoot, matchingRule: acceptRule, }, usesPassword: true, @@ -989,9 +985,7 @@ func TestSSHAuthFlow(t *testing.T) { } var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) if err != nil { if !tc.authErr { @@ -1015,7 +1009,7 @@ func TestSSHAuthFlow(t *testing.T) { if err != nil { t.Errorf("client: %v", err) } - }() + }) if err := s.HandleSSHConn(dc); err != nil { t.Errorf("unexpected error: %v", err) } @@ -1114,7 +1108,7 @@ func TestSSH(t *testing.T) { sc.finalAction = sc.action0 sc.authCompleted.Store(true) - sc.Handler = func(s ssh.Session) { + sc.Handler = func(s gliderssh.Session) { sc.newSSHSession(s).run() } @@ -1229,8 +1223,8 @@ func TestSSH(t *testing.T) { func parseEnv(out []byte) map[string]string { e := map[string]string{} for line := range lineiter.Bytes(out) { - if i := bytes.IndexByte(line, '='); i != -1 { - e[string(line[:i])] = string(line[i+1:]) + if before, after, ok := bytes.Cut(line, []byte{'='}); ok { + e[string(before)] = string(after) } } return e diff --git a/ssh/tailssh/testcontainers/Dockerfile b/ssh/tailssh/testcontainers/Dockerfile index 4ef1c1eb0bb7c..9d662ca1ad597 100644 --- a/ssh/tailssh/testcontainers/Dockerfile +++ b/ssh/tailssh/testcontainers/Dockerfile @@ -28,60 +28,68 @@ COPY tailssh.test . RUN chmod 755 tailscaled -RUN echo "First run tests normally." -RUN eval `ssh-agent -s` && TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestSSHAgentForwarding -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSFTP -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSCP -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSSH -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationParamiko +# Run tests normally. +# On Ubuntu, delete testuser's home directory between tests to verify +# that PAM's pam_mkhomedir recreates it each time. +RUN set -e && \ + eval $(ssh-agent -s) && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestSSHAgentForwarding && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationSFTP && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationSCP && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationSSH && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationParamiko && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestLocalUnixForwarding && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestReverseUnixForwarding && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestUnixForwardingDenied && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestUnixForwardingPathRestriction -RUN echo "Then run tests as non-root user testuser and make sure tests still pass." -RUN touch /tmp/tailscalessh.log -RUN chown testuser:groupone /tmp/tailscalessh.log -RUN TAILSCALED_PATH=`pwd`tailscaled eval `su -m testuser -c ssh-agent -s` && su -m testuser -c "./tailssh.test -test.v -test.run TestSSHAgentForwarding" -RUN TAILSCALED_PATH=`pwd`tailscaled su -m testuser -c "./tailssh.test -test.v -test.run TestIntegration TestDoDropPrivileges" -RUN echo "Also, deny everyone access to the user's home directory and make sure non file-related tests still pass." -RUN mkdir -p /home/testuser && chown testuser:groupone /home/testuser && chmod 0000 /home/testuser -RUN TAILSCALED_PATH=`pwd`tailscaled SKIP_FILE_OPS=1 su -m testuser -c "./tailssh.test -test.v -test.run TestIntegrationSSH" -RUN chmod 0755 /home/testuser -RUN chown root:root /tmp/tailscalessh.log +# Run tests as non-root user testuser and make sure tests still pass. +RUN set -e && \ + touch /tmp/tailscalessh.log && \ + chown testuser:groupone /tmp/tailscalessh.log && \ + export TAILSCALED_PATH=$(pwd)/tailscaled && \ + eval $(su -m testuser -c "ssh-agent -s") && \ + su -m testuser -c "./tailssh.test -test.v -test.run 'TestSSHAgentForwarding|TestIntegration|TestDoDropPrivileges'" && \ + echo "Also, deny everyone access to the user's home directory and make sure non file-related tests still pass." && \ + mkdir -p /home/testuser && chown testuser:groupone /home/testuser && chmod 0000 /home/testuser && \ + SKIP_FILE_OPS=1 su -m testuser -c "./tailssh.test -test.v -test.run TestIntegrationSSH" && \ + chmod 0755 /home/testuser && \ + chown root:root /tmp/tailscalessh.log -RUN if echo "$BASE" | grep "ubuntu:"; then \ - echo "Then run tests in a system that's pretending to be SELinux in enforcing mode" && \ - # Remove execute permissions for /usr/bin/login so that it fails. +# On Ubuntu, run tests pretending to be SELinux in enforcing mode. +RUN if echo "$BASE" | grep -q "ubuntu:"; then \ + set -e && \ + echo "Run tests in a system that's pretending to be SELinux in enforcing mode" && \ mv /usr/bin/login /tmp/login_orig && \ - # Use nonsense for /usr/bin/login so that it fails. - # It's not the same failure mode as in SELinux, but failure is good enough for test. echo "adsfasdfasdf" > /usr/bin/login && \ chmod 755 /usr/bin/login && \ - # Simulate getenforce command printf "#!/bin/bash\necho 'Enforcing'" > /usr/bin/getenforce && \ chmod 755 /usr/bin/getenforce && \ - eval `ssh-agent -s` && TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestSSHAgentForwarding && \ - TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegration && \ + eval $(ssh-agent -s) && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run 'TestSSHAgentForwarding|TestIntegration' && \ mv /tmp/login_orig /usr/bin/login && \ rm /usr/bin/getenforce \ ; fi -RUN echo "Then remove the login command and make sure tests still pass." -RUN rm `which login` -RUN eval `ssh-agent -s` && TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestSSHAgentForwarding -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSFTP -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSCP -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSSH - -RUN echo "Then remove the su command and make sure tests still pass." -RUN chown root:root /tmp/tailscalessh.log -RUN rm `which su` -RUN eval `ssh-agent -s` && TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestSSHAgentForwarding -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegration +# Remove the login command and make sure tests still pass. +RUN set -e && \ + rm $(which login) && \ + eval $(ssh-agent -s) && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestSSHAgentForwarding && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationSFTP && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationSCP && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationSSH -RUN echo "Test doDropPrivileges" -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestDoDropPrivileges +# Remove the su command and make sure tests still pass. +RUN set -e && \ + chown root:root /tmp/tailscalessh.log && \ + rm $(which su) && \ + eval $(ssh-agent -s) && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run 'TestSSHAgentForwarding|TestIntegration|TestDoDropPrivileges' diff --git a/ssh/tailssh/user.go b/ssh/tailssh/user.go index 7da6bb4eb387f..0d2bf31e7f0d5 100644 --- a/ssh/tailssh/user.go +++ b/ssh/tailssh/user.go @@ -104,7 +104,7 @@ func defaultPathForUser(u *user.User) string { if isRoot { return "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" } - return "/usr/local/bin:/usr/bin:/bin:/usr/bn/games" + return "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games" case distro.NixOS: return defaultPathForUserOnNixOS(u) } diff --git a/syncs/shardedint_test.go b/syncs/shardedint_test.go index 8c3f7ef7bd915..c50e411d03f49 100644 --- a/syncs/shardedint_test.go +++ b/syncs/shardedint_test.go @@ -30,7 +30,7 @@ func BenchmarkShardedInt(b *testing.B) { }) }) - b.Run("sharded int", func(b *testing.B) { + b.Run("sharded-int", func(b *testing.B) { m := NewShardedInt() b.RunParallel(func(pb *testing.PB) { for pb.Next() { @@ -60,16 +60,16 @@ func TestShardedInt(t *testing.T) { } }) - t.Run("high concurrency", func(t *testing.T) { + t.Run("high-concurrency", func(t *testing.T) { m := NewShardedInt() wg := sync.WaitGroup{} numWorkers := 1000 numIncrements := 1000 wg.Add(numWorkers) - for i := 0; i < numWorkers; i++ { + for range numWorkers { go func() { defer wg.Done() - for i := 0; i < numIncrements; i++ { + for range numIncrements { m.Add(1) } }() @@ -83,7 +83,7 @@ func TestShardedInt(t *testing.T) { } }) - t.Run("encoding.TextAppender", func(t *testing.T) { + t.Run("encoding-TextAppender", func(t *testing.T) { m := NewShardedInt() m.Add(1) b := make([]byte, 0, 10) diff --git a/syncs/shardvalue_test.go b/syncs/shardvalue_test.go index 1dd0a542e60c2..ab34527abd77f 100644 --- a/syncs/shardvalue_test.go +++ b/syncs/shardvalue_test.go @@ -66,10 +66,10 @@ func TestShardValue(t *testing.T) { iterations := 10000 var wg sync.WaitGroup wg.Add(goroutines) - for i := 0; i < goroutines; i++ { + for range goroutines { go func() { defer wg.Done() - for i := 0; i < iterations; i++ { + for range iterations { sv.One(func(v *intVal) { v.Add(1) }) diff --git a/syncs/syncs_test.go b/syncs/syncs_test.go index 81fcccbf63aca..1e79448ad961e 100644 --- a/syncs/syncs_test.go +++ b/syncs/syncs_test.go @@ -6,6 +6,7 @@ package syncs import ( "context" "io" + "maps" "os" "sync" "testing" @@ -226,9 +227,7 @@ func TestMap(t *testing.T) { } got := map[string]int{} want := map[string]int{"one": 1, "two": 2, "three": 3} - for k, v := range m.All() { - got[k] = v - } + maps.Insert(got, m.All()) if d := cmp.Diff(got, want); d != "" { t.Errorf("Range mismatch (-got +want):\n%s", d) } @@ -243,9 +242,7 @@ func TestMap(t *testing.T) { m.Delete("noexist") got = map[string]int{} want = map[string]int{} - for k, v := range m.All() { - got[k] = v - } + maps.Insert(got, m.All()) if d := cmp.Diff(got, want); d != "" { t.Errorf("Range mismatch (-got +want):\n%s", d) } diff --git a/tailcfg/proto_port_range_test.go b/tailcfg/proto_port_range_test.go index c0c5ff5d5cb76..2fa6c0da0750d 100644 --- a/tailcfg/proto_port_range_test.go +++ b/tailcfg/proto_port_range_test.go @@ -18,31 +18,35 @@ func TestProtoPortRangeParsing(t *testing.T) { return PortRange{First: s, Last: e} } tests := []struct { - in string - out ProtoPortRange - err error + name string + in string + out ProtoPortRange + err error }{ - {in: "tcp:80", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: pr(80, 80)}}, - {in: "80", out: ProtoPortRange{Ports: pr(80, 80)}}, - {in: "*", out: ProtoPortRange{Ports: PortRangeAny}}, - {in: "*:*", out: ProtoPortRange{Ports: PortRangeAny}}, - {in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}}, + {name: "tcp-80", in: "tcp:80", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: pr(80, 80)}}, + {name: "80", in: "80", out: ProtoPortRange{Ports: pr(80, 80)}}, + {name: "star", in: "*", out: ProtoPortRange{Ports: PortRangeAny}}, + {name: "star-star", in: "*:*", out: ProtoPortRange{Ports: PortRangeAny}}, + {name: "tcp-star", in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}}, { - in: "tcp:", - err: vizerror.Errorf("invalid port list: %#v", ""), + name: "tcp-empty-port", + in: "tcp:", + err: vizerror.Errorf("invalid port list: %#v", ""), }, { - in: ":80", - err: errEmptyProtocol, + name: "empty-proto-80", + in: ":80", + err: errEmptyProtocol, }, { - in: "", - err: errEmptyString, + name: "empty-string", + in: "", + err: errEmptyString, }, } for _, tc := range tests { - t.Run(tc.in, func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { var ppr ProtoPortRange err := ppr.UnmarshalText([]byte(tc.in)) if tc.err != err { diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index b49791be6fb39..0cb7597c345a1 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -180,7 +180,12 @@ type CapabilityVersion int // - 131: 2025-11-25: client respects [NodeAttrDefaultAutoUpdate] // - 132: 2026-02-13: client respects [NodeAttrDisableHostsFileUpdates] // - 133: 2026-02-17: client understands [NodeAttrForceRegisterMagicDNSIPv4Only]; MagicDNS IPv6 registered w/ OS by default -const CurrentCapabilityVersion CapabilityVersion = 133 +// - 134: 2026-03-09: Client understands [NodeAttrDisableAndroidBindToActiveNetwork] +// - 135: 2026-03-30: Client understands [NodeAttrCacheNetworkMaps] +// - 136: 2026-04-09: Client understands [NodeAttrDisableLinuxCGNATDropRule] +// - 137: 2026-04-15: Client handles 429 responses to /machine/register. +// - 138: 2026-03-31: can handle C2N /debug/tka. +const CurrentCapabilityVersion CapabilityVersion = 138 // ID is an integer ID for a user, node, or login allocated by the // control plane. @@ -282,6 +287,13 @@ type UserProfile struct { LoginName string // "alice@smith.com"; for display purposes only (provider is not listed) DisplayName string // "Alice Smith" ProfilePicURL string `json:",omitzero"` + + // Groups is a subset of SCIM groups (e.g. "engineering@example.com") + // or group names in the tailnet policy document (e.g. "group:eng") + // that contain this user and that the coordination server was + // configured to report to this node. + // The list is always sorted when loaded from storage. + Groups []string `json:",omitempty"` } func (p *UserProfile) Equal(p2 *UserProfile) bool { @@ -294,7 +306,8 @@ func (p *UserProfile) Equal(p2 *UserProfile) bool { return p.ID == p2.ID && p.LoginName == p2.LoginName && p.DisplayName == p2.DisplayName && - p.ProfilePicURL == p2.ProfilePicURL + p.ProfilePicURL == p2.ProfilePicURL && + slices.Equal(p.Groups, p2.Groups) } // RawMessage is a raw encoded JSON value. It implements Marshaler and @@ -2260,7 +2273,7 @@ type ClientVersion struct { // UrgentSecurityUpdate is set when the client is missing an important // security update. That update may be in LatestVersion or earlier. - // UrgentSecurityUpdate should not be set if RunningLatest is false. + // UrgentSecurityUpdate should not be set if RunningLatest is true. UrgentSecurityUpdate bool `json:",omitempty"` // Notify is whether the client should do an OS-specific notification about @@ -2437,6 +2450,18 @@ type Oauth2Token struct { // These are also referred to as "Node Attributes" in the ACL policy file. type NodeCapability string +// NodeCapabilityPrefix is a prefix for [NodeCapMap] keys that share a common +// namespace, where each entry represents a distinct named instance (e.g. one +// per service). The full key is formed by concatenating the prefix with the +// instance name. +type NodeCapabilityPrefix string + +// ToAttribute returns the full [NodeCapability] key for the given value under +// this prefix, of the form prefix+value. +func (p NodeCapabilityPrefix) ToAttribute(value string) NodeCapability { + return NodeCapability(string(p) + value) +} + const ( CapabilityFileSharing NodeCapability = "https://tailscale.com/cap/file-sharing" CapabilityAdmin NodeCapability = "https://tailscale.com/cap/is-admin" @@ -2450,11 +2475,21 @@ const ( // CapabilityMacUIV2 makes the macOS GUI enable its v2 mode. CapabilityMacUIV2 NodeCapability = "https://tailscale.com/cap/mac-ui-v2" + // CapabilityServicesInDesktopClients enables services list/menu/section in desktop clients. + // If this capability is not present, desktop clients should not show services. + CapabilityServicesInDesktopClients NodeCapability = "https://tailscale.com/cap/services-in-desktop-clients" + // CapabilityBindToInterfaceByRoute changes how Darwin nodes create // sockets (in the net/netns package). See that package for more // details on the behaviour of this capability. CapabilityBindToInterfaceByRoute NodeCapability = "https://tailscale.com/cap/bind-to-interface-by-route" + // NodeAttrDisableAndroidBindToActiveNetwork disables binding sockets to the + // currently active network on Android, which is enabled by default. + // This allows the control plane to turn off the behavior if it causes + // problems. + NodeAttrDisableAndroidBindToActiveNetwork NodeCapability = "disable-android-bind-to-active-network" + // CapabilityDebugDisableAlternateDefaultRouteInterface changes how Darwin // nodes get the default interface. There is an optional hook (used by the // macOS and iOS clients) to override the default interface, this capability @@ -2569,21 +2604,6 @@ const ( // This cannot be set simultaneously with NodeAttrLinuxMustUseIPTables. NodeAttrLinuxMustUseNfTables NodeCapability = "linux-netfilter?v=nftables" - // NodeAttrDisableSeamlessKeyRenewal disables seamless key renewal, which is - // enabled by default in clients as of 2025-09-17 (1.90 and later). - // - // We will use this attribute to manage the rollout, and disable seamless in - // clients with known bugs. - // http://go/seamless-key-renewal - NodeAttrDisableSeamlessKeyRenewal NodeCapability = "disable-seamless-key-renewal" - - // NodeAttrSeamlessKeyRenewal was used to opt-in to seamless key renewal - // during its private alpha. - // - // Deprecated: NodeAttrSeamlessKeyRenewal is deprecated as of CapabilityVersion 126, - // because seamless key renewal is now enabled by default. - NodeAttrSeamlessKeyRenewal NodeCapability = "seamless-key-renewal" - // NodeAttrProbeUDPLifetime makes the client probe UDP path lifetime at the // tail end of an active direct connection in magicsock. NodeAttrProbeUDPLifetime NodeCapability = "probe-udp-lifetime" @@ -2755,6 +2775,29 @@ const ( // See https://github.com/tailscale/tailscale/issues/15404. // TODO(bradfitz): remove this a few releases after 2026-02-16. NodeAttrForceRegisterMagicDNSIPv4Only NodeCapability = "force-register-magicdns-ipv4-only" + + // NodeAttrCacheNetworkMaps instructs the node to persistently cache network + // maps and use them to establish peer connectivity on start, if doing so is + // supported by the client and storage is available. When this attribute is + // absent (or removed), a node that supports netmap caching will ignore and + // discard existing cached maps, and will not store any. + NodeAttrCacheNetworkMaps NodeCapability = "cache-network-maps" + + // NodeAttrDisableLinuxCGNATDropRule tells Linux clients to not insert a + // blanket firewall DROP rule for inbound traffic from the CGNAT IP range + // that does not originate from the Tailscale network interface. + // This enables access to off-tailnet endpoints within that IP range. + NodeAttrDisableLinuxCGNATDropRule NodeCapability = "disable-linux-cgnat-drop-rule" +) + +const ( + // NodeAttrPrefixServices is the prefix for per-service [NodeCapMap] + // entries describing Services visible (accessible) to this node. + // Each value under such a key is of type [ServiceDetails]. + // The suffix after the prefix is an opaque server-chosen identifier; + // consumers must use [ServiceDetails.Name] as the canonical service name + // rather than parsing it from the map key. + NodeAttrPrefixServices NodeCapabilityPrefix = "services/" ) // SetDNSRequest is a request to add a DNS record. @@ -3295,6 +3338,51 @@ const LBHeader = "Ts-Lb" // this client is hosting can be ignored. type ServiceIPMappings map[ServiceName][]netip.Addr +// ServiceAction describes an action that a Tailscale +// client can invoke for a [ServiceDetails]. +type ServiceAction struct { + // Type is the action's identifier i.e. a unique slug corresponding to a well + // known action. It drives icon selection and client application matching. + Type string + + // Port is the target TCP port for this action. It must match one of + // the specific (non-range) TCP ports listed in the enclosing + // [ServiceDetails.Ports]. + Port uint16 + + // DisplayName is an optional human-readable label which may be shown + // in client menus when there are multiple actions to select from. + // If empty, a display name may be inferred from the Type field. + DisplayName string `json:",omitzero"` +} + +// ServiceDetails describes a Service visible to this node. +// It is the value type stored under [NodeAttrPrefixServices]+serviceName keys in [NodeCapMap]. +type ServiceDetails struct { + // Name is the name of the Service, of the form "svc:dns-label". + Name ServiceName + + // DisplayName is an optional human-readable label for the service. + // If empty, Name is used as a fallback by clients. + DisplayName string `json:",omitzero"` + + // Addrs are the IP addresses (IPv4 and IPv6) assigned to this Service. + Addrs []netip.Addr `json:",omitempty"` + + // Ports are the protocol/port combinations the Service accepts. + Ports []ProtoPortRange `json:",omitempty"` + + // Actions is an optional list of actions describing how a client may + // interact with this service. Each action maps a [ServiceAction.Type] to a + // specific TCP port; the port must match one of the concrete (non-range) + // ports listed in Ports. + // + // Multiple actions may reference the same port. Not every port requires + // a corresponding action. When Actions has length zero, clients may infer + // default interactions from Ports. + Actions []ServiceAction `json:",omitzero"` +} + // ClientAuditAction represents an auditable action that a client can report to the // control plane. These actions must correspond to the supported actions // in the control plane. diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index a60f301d763c7..df2d6d9aa7e8d 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -13,7 +13,6 @@ import ( "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/opt" - "tailscale.com/types/ptr" "tailscale.com/types/structs" "tailscale.com/types/tkatype" ) @@ -53,10 +52,10 @@ func (src *Node) Clone() *Node { dst.Tags = append(src.Tags[:0:0], src.Tags...) dst.PrimaryRoutes = append(src.PrimaryRoutes[:0:0], src.PrimaryRoutes...) if dst.LastSeen != nil { - dst.LastSeen = ptr.To(*src.LastSeen) + dst.LastSeen = new(*src.LastSeen) } if dst.Online != nil { - dst.Online = ptr.To(*src.Online) + dst.Online = new(*src.Online) } dst.Capabilities = append(src.Capabilities[:0:0], src.Capabilities...) if dst.CapMap != nil { @@ -66,10 +65,10 @@ func (src *Node) Clone() *Node { } } if dst.SelfNodeV4MasqAddrForThisPeer != nil { - dst.SelfNodeV4MasqAddrForThisPeer = ptr.To(*src.SelfNodeV4MasqAddrForThisPeer) + dst.SelfNodeV4MasqAddrForThisPeer = new(*src.SelfNodeV4MasqAddrForThisPeer) } if dst.SelfNodeV6MasqAddrForThisPeer != nil { - dst.SelfNodeV6MasqAddrForThisPeer = ptr.To(*src.SelfNodeV6MasqAddrForThisPeer) + dst.SelfNodeV6MasqAddrForThisPeer = new(*src.SelfNodeV6MasqAddrForThisPeer) } if src.ExitNodeDNSResolvers != nil { dst.ExitNodeDNSResolvers = make([]*dnstype.Resolver, len(src.ExitNodeDNSResolvers)) @@ -139,10 +138,10 @@ func (src *Hostinfo) Clone() *Hostinfo { dst.NetInfo = src.NetInfo.Clone() dst.SSH_HostKeys = append(src.SSH_HostKeys[:0:0], src.SSH_HostKeys...) if dst.Location != nil { - dst.Location = ptr.To(*src.Location) + dst.Location = new(*src.Location) } if dst.TPM != nil { - dst.TPM = ptr.To(*src.TPM) + dst.TPM = new(*src.TPM) } return dst } @@ -263,8 +262,19 @@ func (src *DNSConfig) Clone() *DNSConfig { } if dst.Routes != nil { dst.Routes = map[string][]*dnstype.Resolver{} - for k := range src.Routes { - dst.Routes[k] = append([]*dnstype.Resolver{}, src.Routes[k]...) + for k, sv := range src.Routes { + if sv == nil { + dst.Routes[k] = nil + continue + } + dst.Routes[k] = make([]*dnstype.Resolver, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.Routes[k][i] = nil + } else { + dst.Routes[k][i] = sv[i].Clone() + } + } } } if src.FallbackResolvers != nil { @@ -331,7 +341,7 @@ func (src *RegisterResponseAuth) Clone() *RegisterResponseAuth { dst := new(RegisterResponseAuth) *dst = *src if dst.Oauth2Token != nil { - dst.Oauth2Token = ptr.To(*src.Oauth2Token) + dst.Oauth2Token = new(*src.Oauth2Token) } return dst } @@ -355,7 +365,7 @@ func (src *RegisterRequest) Clone() *RegisterRequest { dst.Hostinfo = src.Hostinfo.Clone() dst.NodeKeySignature = append(src.NodeKeySignature[:0:0], src.NodeKeySignature...) if dst.Timestamp != nil { - dst.Timestamp = ptr.To(*src.Timestamp) + dst.Timestamp = new(*src.Timestamp) } dst.DeviceCert = append(src.DeviceCert[:0:0], src.DeviceCert...) dst.Signature = append(src.Signature[:0:0], src.Signature...) @@ -413,7 +423,7 @@ func (src *DERPRegion) Clone() *DERPRegion { if src.Nodes[i] == nil { dst.Nodes[i] = nil } else { - dst.Nodes[i] = ptr.To(*src.Nodes[i]) + dst.Nodes[i] = new(*src.Nodes[i]) } } } @@ -497,7 +507,7 @@ func (src *SSHRule) Clone() *SSHRule { dst := new(SSHRule) *dst = *src if dst.RuleExpires != nil { - dst.RuleExpires = ptr.To(*src.RuleExpires) + dst.RuleExpires = new(*src.RuleExpires) } if src.Principals != nil { dst.Principals = make([]*SSHPrincipal, len(src.Principals)) @@ -534,7 +544,7 @@ func (src *SSHAction) Clone() *SSHAction { *dst = *src dst.Recorders = append(src.Recorders[:0:0], src.Recorders...) if dst.OnRecordingFailure != nil { - dst.OnRecordingFailure = ptr.To(*src.OnRecordingFailure) + dst.OnRecordingFailure = new(*src.OnRecordingFailure) } return dst } @@ -621,6 +631,7 @@ func (src *UserProfile) Clone() *UserProfile { } dst := new(UserProfile) *dst = *src + dst.Groups = append(src.Groups[:0:0], src.Groups...) return dst } @@ -630,6 +641,7 @@ var _UserProfileCloneNeedsRegeneration = UserProfile(struct { LoginName string DisplayName string ProfilePicURL string + Groups []string }{}) // Clone makes a deep copy of VIPService. diff --git a/tailcfg/tailcfg_test.go b/tailcfg/tailcfg_test.go index f649e43ab57b8..8dd9191b63c84 100644 --- a/tailcfg/tailcfg_test.go +++ b/tailcfg/tailcfg_test.go @@ -17,13 +17,12 @@ import ( "tailscale.com/tstest/deptest" "tailscale.com/types/key" "tailscale.com/types/opt" - "tailscale.com/types/ptr" "tailscale.com/util/must" ) func fieldsOf(t reflect.Type) (fields []string) { - for i := range t.NumField() { - fields = append(fields, t.Field(i).Name) + for field := range t.Fields() { + fields = append(fields, field.Name) } return } @@ -539,22 +538,22 @@ func TestNodeEqual(t *testing.T) { }, { &Node{}, - &Node{SelfNodeV4MasqAddrForThisPeer: ptr.To(netip.MustParseAddr("100.64.0.1"))}, + &Node{SelfNodeV4MasqAddrForThisPeer: new(netip.MustParseAddr("100.64.0.1"))}, false, }, { - &Node{SelfNodeV4MasqAddrForThisPeer: ptr.To(netip.MustParseAddr("100.64.0.1"))}, - &Node{SelfNodeV4MasqAddrForThisPeer: ptr.To(netip.MustParseAddr("100.64.0.1"))}, + &Node{SelfNodeV4MasqAddrForThisPeer: new(netip.MustParseAddr("100.64.0.1"))}, + &Node{SelfNodeV4MasqAddrForThisPeer: new(netip.MustParseAddr("100.64.0.1"))}, true, }, { &Node{}, - &Node{SelfNodeV6MasqAddrForThisPeer: ptr.To(netip.MustParseAddr("2001::3456"))}, + &Node{SelfNodeV6MasqAddrForThisPeer: new(netip.MustParseAddr("2001::3456"))}, false, }, { - &Node{SelfNodeV6MasqAddrForThisPeer: ptr.To(netip.MustParseAddr("2001::3456"))}, - &Node{SelfNodeV6MasqAddrForThisPeer: ptr.To(netip.MustParseAddr("2001::3456"))}, + &Node{SelfNodeV6MasqAddrForThisPeer: new(netip.MustParseAddr("2001::3456"))}, + &Node{SelfNodeV6MasqAddrForThisPeer: new(netip.MustParseAddr("2001::3456"))}, true, }, { @@ -842,12 +841,12 @@ func TestMarshalToRawMessageAndBack(t *testing.T) { capType: PeerCapability("foo"), }, { - name: "some values", + name: "some-values", val: testRule{Ports: []int{80, 443}, Name: "foo"}, capType: PeerCapability("foo"), }, { - name: "all values", + name: "all-values", val: testRule{Ports: []int{80, 443}, Name: "foo", ToggleOn: true, Groups: inner{Groups: []string{"foo", "bar"}}, Addrs: []netip.AddrPort{testip}}, capType: PeerCapability("foo"), }, diff --git a/tailcfg/tailcfg_view.go b/tailcfg/tailcfg_view.go index 7960000fd3d6a..9900efbcc3d63 100644 --- a/tailcfg/tailcfg_view.go +++ b/tailcfg/tailcfg_view.go @@ -2505,8 +2505,15 @@ func (v UserProfileView) ID() UserID { return v.Đļ.ID } func (v UserProfileView) LoginName() string { return v.Đļ.LoginName } // "Alice Smith" -func (v UserProfileView) DisplayName() string { return v.Đļ.DisplayName } -func (v UserProfileView) ProfilePicURL() string { return v.Đļ.ProfilePicURL } +func (v UserProfileView) DisplayName() string { return v.Đļ.DisplayName } +func (v UserProfileView) ProfilePicURL() string { return v.Đļ.ProfilePicURL } + +// Groups is a subset of SCIM groups (e.g. "engineering@example.com") +// or group names in the tailnet policy document (e.g. "group:eng") +// that contain this user and that the coordination server was +// configured to report to this node. +// The list is always sorted when loaded from storage. +func (v UserProfileView) Groups() views.Slice[string] { return views.SliceOf(v.Đļ.Groups) } func (v UserProfileView) Equal(v2 UserProfileView) bool { return v.Đļ.Equal(v2.Đļ) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -2515,6 +2522,7 @@ var _UserProfileViewNeedsRegeneration = UserProfile(struct { LoginName string DisplayName string ProfilePicURL string + Groups []string }{}) // View returns a read-only view of VIPService. diff --git a/tempfork/gliderlabs/ssh/LICENSE b/tempfork/gliderlabs/ssh/LICENSE deleted file mode 100644 index 4a03f02a28185..0000000000000 --- a/tempfork/gliderlabs/ssh/LICENSE +++ /dev/null @@ -1,27 +0,0 @@ -Copyright (c) 2016 Glider Labs. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Glider Labs nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/tempfork/gliderlabs/ssh/README.md b/tempfork/gliderlabs/ssh/README.md deleted file mode 100644 index 79b5b89fa8a94..0000000000000 --- a/tempfork/gliderlabs/ssh/README.md +++ /dev/null @@ -1,96 +0,0 @@ -# gliderlabs/ssh - -[![GoDoc](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh?status.svg)](https://godoc.org/github.com/gliderlabs/ssh) -[![CircleCI](https://img.shields.io/circleci/project/github/gliderlabs/ssh.svg)](https://circleci.com/gh/gliderlabs/ssh) -[![Go Report Card](https://goreportcard.com/badge/tailscale.com/tempfork/gliderlabs/ssh)](https://goreportcard.com/report/github.com/gliderlabs/ssh) -[![OpenCollective](https://opencollective.com/ssh/sponsors/badge.svg)](#sponsors) -[![Slack](http://slack.gliderlabs.com/badge.svg)](http://slack.gliderlabs.com) -[![Email Updates](https://img.shields.io/badge/updates-subscribe-yellow.svg)](https://app.convertkit.com/landing_pages/243312) - -> The Glider Labs SSH server package is dope. —[@bradfitz](https://twitter.com/bradfitz), Go team member - -This Go package wraps the [crypto/ssh -package](https://godoc.org/golang.org/x/crypto/ssh) with a higher-level API for -building SSH servers. The goal of the API was to make it as simple as using -[net/http](https://golang.org/pkg/net/http/), so the API is very similar: - -```go - package main - - import ( - "tailscale.com/tempfork/gliderlabs/ssh" - "io" - "log" - ) - - func main() { - ssh.Handle(func(s ssh.Session) { - io.WriteString(s, "Hello world\n") - }) - - log.Fatal(ssh.ListenAndServe(":2222", nil)) - } - -``` -This package was built by [@progrium](https://twitter.com/progrium) after working on nearly a dozen projects at Glider Labs using SSH and collaborating with [@shazow](https://twitter.com/shazow) (known for [ssh-chat](https://github.com/shazow/ssh-chat)). - -## Examples - -A bunch of great examples are in the `_examples` directory. - -## Usage - -[See GoDoc reference.](https://godoc.org/tailscale.com/tempfork/gliderlabs/ssh) - -## Contributing - -Pull requests are welcome! However, since this project is very much about API -design, please submit API changes as issues to discuss before submitting PRs. - -Also, you can [join our Slack](http://slack.gliderlabs.com) to discuss as well. - -## Roadmap - -* Non-session channel handlers -* Cleanup callback API -* 1.0 release -* High-level client? - -## Sponsors - -Become a sponsor and get your logo on our README on Github with a link to your site. [[Become a sponsor](https://opencollective.com/ssh#sponsor)] - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -## License - -[BSD](LICENSE) diff --git a/tempfork/gliderlabs/ssh/agent.go b/tempfork/gliderlabs/ssh/agent.go deleted file mode 100644 index 99e84c1e5c64c..0000000000000 --- a/tempfork/gliderlabs/ssh/agent.go +++ /dev/null @@ -1,83 +0,0 @@ -package ssh - -import ( - "io" - "net" - "os" - "path" - "sync" - - gossh "golang.org/x/crypto/ssh" -) - -const ( - agentRequestType = "auth-agent-req@openssh.com" - agentChannelType = "auth-agent@openssh.com" - - agentTempDir = "auth-agent" - agentListenFile = "listener.sock" -) - -// contextKeyAgentRequest is an internal context key for storing if the -// client requested agent forwarding -var contextKeyAgentRequest = &contextKey{"auth-agent-req"} - -// SetAgentRequested sets up the session context so that AgentRequested -// returns true. -func SetAgentRequested(ctx Context) { - ctx.SetValue(contextKeyAgentRequest, true) -} - -// AgentRequested returns true if the client requested agent forwarding. -func AgentRequested(sess Session) bool { - return sess.Context().Value(contextKeyAgentRequest) == true -} - -// NewAgentListener sets up a temporary Unix socket that can be communicated -// to the session environment and used for forwarding connections. -func NewAgentListener() (net.Listener, error) { - dir, err := os.MkdirTemp("", agentTempDir) - if err != nil { - return nil, err - } - l, err := net.Listen("unix", path.Join(dir, agentListenFile)) - if err != nil { - return nil, err - } - return l, nil -} - -// ForwardAgentConnections takes connections from a listener to proxy into the -// session on the OpenSSH channel for agent connections. It blocks and services -// connections until the listener stop accepting. -func ForwardAgentConnections(l net.Listener, s Session) { - sshConn := s.Context().Value(ContextKeyConn).(gossh.Conn) - for { - conn, err := l.Accept() - if err != nil { - return - } - go func(conn net.Conn) { - defer conn.Close() - channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil) - if err != nil { - return - } - defer channel.Close() - go gossh.DiscardRequests(reqs) - var wg sync.WaitGroup - wg.Add(2) - go func() { - io.Copy(conn, channel) - conn.(*net.UnixConn).CloseWrite() - wg.Done() - }() - go func() { - io.Copy(channel, conn) - channel.CloseWrite() - wg.Done() - }() - wg.Wait() - }(conn) - } -} diff --git a/tempfork/gliderlabs/ssh/conn.go b/tempfork/gliderlabs/ssh/conn.go deleted file mode 100644 index ebef8845baccb..0000000000000 --- a/tempfork/gliderlabs/ssh/conn.go +++ /dev/null @@ -1,55 +0,0 @@ -package ssh - -import ( - "context" - "net" - "time" -) - -type serverConn struct { - net.Conn - - idleTimeout time.Duration - maxDeadline time.Time - closeCanceler context.CancelFunc -} - -func (c *serverConn) Write(p []byte) (n int, err error) { - c.updateDeadline() - n, err = c.Conn.Write(p) - if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) Read(b []byte) (n int, err error) { - c.updateDeadline() - n, err = c.Conn.Read(b) - if _, isNetErr := err.(net.Error); isNetErr && c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) Close() (err error) { - err = c.Conn.Close() - if c.closeCanceler != nil { - c.closeCanceler() - } - return -} - -func (c *serverConn) updateDeadline() { - switch { - case c.idleTimeout > 0: - idleDeadline := time.Now().Add(c.idleTimeout) - if idleDeadline.Unix() < c.maxDeadline.Unix() || c.maxDeadline.IsZero() { - c.Conn.SetDeadline(idleDeadline) - return - } - fallthrough - default: - c.Conn.SetDeadline(c.maxDeadline) - } -} diff --git a/tempfork/gliderlabs/ssh/context.go b/tempfork/gliderlabs/ssh/context.go deleted file mode 100644 index 505a43dbf3ffe..0000000000000 --- a/tempfork/gliderlabs/ssh/context.go +++ /dev/null @@ -1,155 +0,0 @@ -package ssh - -import ( - "context" - "encoding/hex" - "net" - "sync" - - gossh "golang.org/x/crypto/ssh" -) - -// contextKey is a value for use with context.WithValue. It's used as -// a pointer so it fits in an interface{} without allocation. -type contextKey struct { - name string -} - -var ( - // ContextKeyUser is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeyUser = &contextKey{"user"} - - // ContextKeySessionID is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeySessionID = &contextKey{"session-id"} - - // ContextKeyPermissions is a context key for use with Contexts in this package. - // The associated value will be of type *Permissions. - ContextKeyPermissions = &contextKey{"permissions"} - - // ContextKeyClientVersion is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeyClientVersion = &contextKey{"client-version"} - - // ContextKeyServerVersion is a context key for use with Contexts in this package. - // The associated value will be of type string. - ContextKeyServerVersion = &contextKey{"server-version"} - - // ContextKeyLocalAddr is a context key for use with Contexts in this package. - // The associated value will be of type net.Addr. - ContextKeyLocalAddr = &contextKey{"local-addr"} - - // ContextKeyRemoteAddr is a context key for use with Contexts in this package. - // The associated value will be of type net.Addr. - ContextKeyRemoteAddr = &contextKey{"remote-addr"} - - // ContextKeyServer is a context key for use with Contexts in this package. - // The associated value will be of type *Server. - ContextKeyServer = &contextKey{"ssh-server"} - - // ContextKeyConn is a context key for use with Contexts in this package. - // The associated value will be of type gossh.ServerConn. - ContextKeyConn = &contextKey{"ssh-conn"} - - // ContextKeyPublicKey is a context key for use with Contexts in this package. - // The associated value will be of type PublicKey. - ContextKeyPublicKey = &contextKey{"public-key"} -) - -// Context is a package specific context interface. It exposes connection -// metadata and allows new values to be easily written to it. It's used in -// authentication handlers and callbacks, and its underlying context.Context is -// exposed on Session in the session Handler. A connection-scoped lock is also -// embedded in the context to make it easier to limit operations per-connection. -type Context interface { - context.Context - sync.Locker - - // User returns the username used when establishing the SSH connection. - User() string - - // SessionID returns the session hash. - SessionID() string - - // ClientVersion returns the version reported by the client. - ClientVersion() string - - // ServerVersion returns the version reported by the server. - ServerVersion() string - - // RemoteAddr returns the remote address for this connection. - RemoteAddr() net.Addr - - // LocalAddr returns the local address for this connection. - LocalAddr() net.Addr - - // Permissions returns the Permissions object used for this connection. - Permissions() *Permissions - - // SetValue allows you to easily write new values into the underlying context. - SetValue(key, value interface{}) -} - -type sshContext struct { - context.Context - *sync.Mutex -} - -func newContext(srv *Server) (*sshContext, context.CancelFunc) { - innerCtx, cancel := context.WithCancel(context.Background()) - ctx := &sshContext{innerCtx, &sync.Mutex{}} - ctx.SetValue(ContextKeyServer, srv) - perms := &Permissions{&gossh.Permissions{}} - ctx.SetValue(ContextKeyPermissions, perms) - return ctx, cancel -} - -// this is separate from newContext because we will get ConnMetadata -// at different points so it needs to be applied separately -func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { - if ctx.Value(ContextKeySessionID) != nil { - return - } - ctx.SetValue(ContextKeySessionID, hex.EncodeToString(conn.SessionID())) - ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion())) - ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion())) - ctx.SetValue(ContextKeyUser, conn.User()) - ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr()) - ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) -} - -func (ctx *sshContext) SetValue(key, value interface{}) { - ctx.Context = context.WithValue(ctx.Context, key, value) -} - -func (ctx *sshContext) User() string { - return ctx.Value(ContextKeyUser).(string) -} - -func (ctx *sshContext) SessionID() string { - return ctx.Value(ContextKeySessionID).(string) -} - -func (ctx *sshContext) ClientVersion() string { - return ctx.Value(ContextKeyClientVersion).(string) -} - -func (ctx *sshContext) ServerVersion() string { - return ctx.Value(ContextKeyServerVersion).(string) -} - -func (ctx *sshContext) RemoteAddr() net.Addr { - if addr, ok := ctx.Value(ContextKeyRemoteAddr).(net.Addr); ok { - return addr - } - return nil -} - -func (ctx *sshContext) LocalAddr() net.Addr { - return ctx.Value(ContextKeyLocalAddr).(net.Addr) -} - -func (ctx *sshContext) Permissions() *Permissions { - return ctx.Value(ContextKeyPermissions).(*Permissions) -} diff --git a/tempfork/gliderlabs/ssh/context_test.go b/tempfork/gliderlabs/ssh/context_test.go deleted file mode 100644 index dcbd326b77809..0000000000000 --- a/tempfork/gliderlabs/ssh/context_test.go +++ /dev/null @@ -1,49 +0,0 @@ -//go:build glidertests - -package ssh - -import "testing" - -func TestSetPermissions(t *testing.T) { - t.Parallel() - permsExt := map[string]string{ - "foo": "bar", - } - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - if _, ok := s.Permissions().Extensions["foo"]; !ok { - t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt) - } - }, - }, nil, PasswordAuth(func(ctx Context, password string) bool { - ctx.Permissions().Extensions = permsExt - return true - })) - defer cleanup() - if err := session.Run(""); err != nil { - t.Fatal(err) - } -} - -func TestSetValue(t *testing.T) { - t.Parallel() - value := map[string]string{ - "foo": "bar", - } - key := "testValue" - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - v := s.Context().Value(key).(map[string]string) - if v["foo"] != value["foo"] { - t.Fatalf("got %#v; want %#v", v, value) - } - }, - }, nil, PasswordAuth(func(ctx Context, password string) bool { - ctx.SetValue(key, value) - return true - })) - defer cleanup() - if err := session.Run(""); err != nil { - t.Fatal(err) - } -} diff --git a/tempfork/gliderlabs/ssh/doc.go b/tempfork/gliderlabs/ssh/doc.go deleted file mode 100644 index d139191768d55..0000000000000 --- a/tempfork/gliderlabs/ssh/doc.go +++ /dev/null @@ -1,45 +0,0 @@ -/* -Package ssh wraps the crypto/ssh package with a higher-level API for building -SSH servers. The goal of the API was to make it as simple as using net/http, so -the API is very similar. - -You should be able to build any SSH server using only this package, which wraps -relevant types and some functions from crypto/ssh. However, you still need to -use crypto/ssh for building SSH clients. - -ListenAndServe starts an SSH server with a given address, handler, and options. The -handler is usually nil, which means to use DefaultHandler. Handle sets DefaultHandler: - - ssh.Handle(func(s ssh.Session) { - io.WriteString(s, "Hello world\n") - }) - - log.Fatal(ssh.ListenAndServe(":2222", nil)) - -If you don't specify a host key, it will generate one every time. This is convenient -except you'll have to deal with clients being confused that the host key is different. -It's a better idea to generate or point to an existing key on your system: - - log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/Users/progrium/.ssh/id_rsa"))) - -Although all options have functional option helpers, another way to control the -server's behavior is by creating a custom Server: - - s := &ssh.Server{ - Addr: ":2222", - Handler: sessionHandler, - PublicKeyHandler: authHandler, - } - s.AddHostKey(hostKeySigner) - - log.Fatal(s.ListenAndServe()) - -This package automatically handles basic SSH requests like setting environment -variables, requesting PTY, and changing window size. These requests are -processed, responded to, and any relevant state is updated. This state is then -exposed to you via the Session interface. - -The one big feature missing from the Session abstraction is signals. This was -started, but not completed. Pull Requests welcome! -*/ -package ssh diff --git a/tempfork/gliderlabs/ssh/example_test.go b/tempfork/gliderlabs/ssh/example_test.go deleted file mode 100644 index c174bc4ae190e..0000000000000 --- a/tempfork/gliderlabs/ssh/example_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package ssh_test - -import ( - "errors" - "io" - "os" - - "tailscale.com/tempfork/gliderlabs/ssh" -) - -func ExampleListenAndServe() { - ssh.ListenAndServe(":2222", func(s ssh.Session) { - io.WriteString(s, "Hello world\n") - }) -} - -func ExamplePasswordAuth() { - ssh.ListenAndServe(":2222", nil, - ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool { - return pass == "secret" - }), - ) -} - -func ExampleNoPty() { - ssh.ListenAndServe(":2222", nil, ssh.NoPty()) -} - -func ExamplePublicKeyAuth() { - ssh.ListenAndServe(":2222", nil, - ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) error { - data, err := os.ReadFile("/path/to/allowed/key.pub") - if err != nil { - return err - } - allowed, _, _, _, err := ssh.ParseAuthorizedKey(data) - if err != nil { - return err - } - if !ssh.KeysEqual(key, allowed) { - return errors.New("some error") - } - return nil - }), - ) -} - -func ExampleHostKeyFile() { - ssh.ListenAndServe(":2222", nil, ssh.HostKeyFile("/path/to/host/key")) -} diff --git a/tempfork/gliderlabs/ssh/options.go b/tempfork/gliderlabs/ssh/options.go deleted file mode 100644 index 29c8ef141842b..0000000000000 --- a/tempfork/gliderlabs/ssh/options.go +++ /dev/null @@ -1,84 +0,0 @@ -package ssh - -import ( - "os" - - gossh "golang.org/x/crypto/ssh" -) - -// PasswordAuth returns a functional option that sets PasswordHandler on the server. -func PasswordAuth(fn PasswordHandler) Option { - return func(srv *Server) error { - srv.PasswordHandler = fn - return nil - } -} - -// PublicKeyAuth returns a functional option that sets PublicKeyHandler on the server. -func PublicKeyAuth(fn PublicKeyHandler) Option { - return func(srv *Server) error { - srv.PublicKeyHandler = fn - return nil - } -} - -// HostKeyFile returns a functional option that adds HostSigners to the server -// from a PEM file at filepath. -func HostKeyFile(filepath string) Option { - return func(srv *Server) error { - pemBytes, err := os.ReadFile(filepath) - if err != nil { - return err - } - - signer, err := gossh.ParsePrivateKey(pemBytes) - if err != nil { - return err - } - - srv.AddHostKey(signer) - - return nil - } -} - -func KeyboardInteractiveAuth(fn KeyboardInteractiveHandler) Option { - return func(srv *Server) error { - srv.KeyboardInteractiveHandler = fn - return nil - } -} - -// HostKeyPEM returns a functional option that adds HostSigners to the server -// from a PEM file as bytes. -func HostKeyPEM(bytes []byte) Option { - return func(srv *Server) error { - signer, err := gossh.ParsePrivateKey(bytes) - if err != nil { - return err - } - - srv.AddHostKey(signer) - - return nil - } -} - -// NoPty returns a functional option that sets PtyCallback to return false, -// denying PTY requests. -func NoPty() Option { - return func(srv *Server) error { - srv.PtyCallback = func(ctx Context, pty Pty) bool { - return false - } - return nil - } -} - -// WrapConn returns a functional option that sets ConnCallback on the server. -func WrapConn(fn ConnCallback) Option { - return func(srv *Server) error { - srv.ConnCallback = fn - return nil - } -} diff --git a/tempfork/gliderlabs/ssh/options_test.go b/tempfork/gliderlabs/ssh/options_test.go deleted file mode 100644 index 47342b0f67923..0000000000000 --- a/tempfork/gliderlabs/ssh/options_test.go +++ /dev/null @@ -1,111 +0,0 @@ -//go:build glidertests - -package ssh - -import ( - "net" - "strings" - "sync/atomic" - "testing" - - gossh "golang.org/x/crypto/ssh" -) - -func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) { - for _, option := range options { - if err := srv.SetOption(option); err != nil { - t.Fatal(err) - } - } - return newTestSession(t, srv, cfg) -} - -func TestPasswordAuth(t *testing.T) { - t.Parallel() - testUser := "testuser" - testPass := "testpass" - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - // noop - }, - }, &gossh.ClientConfig{ - User: testUser, - Auth: []gossh.AuthMethod{ - gossh.Password(testPass), - }, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }, PasswordAuth(func(ctx Context, password string) bool { - if ctx.User() != testUser { - t.Fatalf("user = %#v; want %#v", ctx.User(), testUser) - } - if password != testPass { - t.Fatalf("user = %#v; want %#v", password, testPass) - } - return true - })) - defer cleanup() - if err := session.Run(""); err != nil { - t.Fatal(err) - } -} - -func TestPasswordAuthBadPass(t *testing.T) { - t.Parallel() - l := newLocalListener() - srv := &Server{Handler: func(s Session) {}} - srv.SetOption(PasswordAuth(func(ctx Context, password string) bool { - return false - })) - go srv.serveOnce(l) - _, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{ - User: "testuser", - Auth: []gossh.AuthMethod{ - gossh.Password("testpass"), - }, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }) - if err != nil { - if !strings.Contains(err.Error(), "unable to authenticate") { - t.Fatal(err) - } - } -} - -type wrappedConn struct { - net.Conn - written int32 -} - -func (c *wrappedConn) Write(p []byte) (n int, err error) { - n, err = c.Conn.Write(p) - atomic.AddInt32(&(c.written), int32(n)) - return -} - -func TestConnWrapping(t *testing.T) { - t.Parallel() - var wrapped *wrappedConn - session, _, cleanup := newTestSessionWithOptions(t, &Server{ - Handler: func(s Session) { - // nothing - }, - }, &gossh.ClientConfig{ - User: "testuser", - Auth: []gossh.AuthMethod{ - gossh.Password("testpass"), - }, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - }, PasswordAuth(func(ctx Context, password string) bool { - return true - }), WrapConn(func(ctx Context, conn net.Conn) net.Conn { - wrapped = &wrappedConn{conn, 0} - return wrapped - })) - defer cleanup() - if err := session.Shell(); err != nil { - t.Fatal(err) - } - if atomic.LoadInt32(&(wrapped.written)) == 0 { - t.Fatal("wrapped conn not written to") - } -} diff --git a/tempfork/gliderlabs/ssh/server.go b/tempfork/gliderlabs/ssh/server.go deleted file mode 100644 index 473e5fbd6fc8f..0000000000000 --- a/tempfork/gliderlabs/ssh/server.go +++ /dev/null @@ -1,459 +0,0 @@ -package ssh - -import ( - "context" - "errors" - "fmt" - "net" - "sync" - "time" - - gossh "golang.org/x/crypto/ssh" -) - -// ErrServerClosed is returned by the Server's Serve, ListenAndServe, -// and ListenAndServeTLS methods after a call to Shutdown or Close. -var ErrServerClosed = errors.New("ssh: Server closed") - -type SubsystemHandler func(s Session) - -var DefaultSubsystemHandlers = map[string]SubsystemHandler{} - -type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) - -var DefaultRequestHandlers = map[string]RequestHandler{} - -type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) - -var DefaultChannelHandlers = map[string]ChannelHandler{ - "session": DefaultSessionHandler, -} - -// Server defines parameters for running an SSH server. The zero value for -// Server is a valid configuration. When both PasswordHandler and -// PublicKeyHandler are nil, no client authentication is performed. -type Server struct { - Addr string // TCP address to listen on, ":22" if empty - Handler Handler // handler to invoke, ssh.DefaultHandler if nil - HostSigners []Signer // private keys for the host key, must have at least one - Version string // server version to be sent before the initial handshake - - KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler - PasswordHandler PasswordHandler // password authentication handler - PublicKeyHandler PublicKeyHandler // public key authentication handler - NoClientAuthHandler NoClientAuthHandler // no client authentication handler - PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil - ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling - LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil - ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil - ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options - SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions - - ConnectionFailedCallback ConnectionFailedCallback // callback to report connection failures - - IdleTimeout time.Duration // connection timeout when no activity, none if empty - MaxTimeout time.Duration // absolute connection timeout, none if empty - - // ChannelHandlers allow overriding the built-in session handlers or provide - // extensions to the protocol, such as tcpip forwarding. By default only the - // "session" handler is enabled. - ChannelHandlers map[string]ChannelHandler - - // RequestHandlers allow overriding the server-level request handlers or - // provide extensions to the protocol, such as tcpip forwarding. By default - // no handlers are enabled. - RequestHandlers map[string]RequestHandler - - // SubsystemHandlers are handlers which are similar to the usual SSH command - // handlers, but handle named subsystems. - SubsystemHandlers map[string]SubsystemHandler - - listenerWg sync.WaitGroup - mu sync.RWMutex - listeners map[net.Listener]struct{} - conns map[*gossh.ServerConn]struct{} - connWg sync.WaitGroup - doneChan chan struct{} -} - -func (srv *Server) ensureHostSigner() error { - srv.mu.Lock() - defer srv.mu.Unlock() - - if len(srv.HostSigners) == 0 { - signer, err := generateSigner() - if err != nil { - return err - } - srv.HostSigners = append(srv.HostSigners, signer) - } - return nil -} - -func (srv *Server) ensureHandlers() { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.RequestHandlers == nil { - srv.RequestHandlers = map[string]RequestHandler{} - for k, v := range DefaultRequestHandlers { - srv.RequestHandlers[k] = v - } - } - if srv.ChannelHandlers == nil { - srv.ChannelHandlers = map[string]ChannelHandler{} - for k, v := range DefaultChannelHandlers { - srv.ChannelHandlers[k] = v - } - } - if srv.SubsystemHandlers == nil { - srv.SubsystemHandlers = map[string]SubsystemHandler{} - for k, v := range DefaultSubsystemHandlers { - srv.SubsystemHandlers[k] = v - } - } -} - -func (srv *Server) config(ctx Context) *gossh.ServerConfig { - srv.mu.RLock() - defer srv.mu.RUnlock() - - var config *gossh.ServerConfig - if srv.ServerConfigCallback == nil { - config = &gossh.ServerConfig{} - } else { - config = srv.ServerConfigCallback(ctx) - } - for _, signer := range srv.HostSigners { - config.AddHostKey(signer) - } - if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil && srv.KeyboardInteractiveHandler == nil { - config.NoClientAuth = true - } - if srv.Version != "" { - config.ServerVersion = "SSH-2.0-" + srv.Version - } - if srv.PasswordHandler != nil { - config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if ok := srv.PasswordHandler(ctx, string(password)); !ok { - return ctx.Permissions().Permissions, fmt.Errorf("permission denied") - } - return ctx.Permissions().Permissions, nil - } - } - if srv.PublicKeyHandler != nil { - config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if err := srv.PublicKeyHandler(ctx, key); err != nil { - return ctx.Permissions().Permissions, err - } - ctx.SetValue(ContextKeyPublicKey, key) - return ctx.Permissions().Permissions, nil - } - } - if srv.KeyboardInteractiveHandler != nil { - config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok { - return ctx.Permissions().Permissions, fmt.Errorf("permission denied") - } - return ctx.Permissions().Permissions, nil - } - } - if srv.NoClientAuthHandler != nil { - config.NoClientAuthCallback = func(conn gossh.ConnMetadata) (*gossh.Permissions, error) { - applyConnMetadata(ctx, conn) - if err := srv.NoClientAuthHandler(ctx); err != nil { - return ctx.Permissions().Permissions, err - } - return ctx.Permissions().Permissions, nil - } - } - return config -} - -// Handle sets the Handler for the server. -func (srv *Server) Handle(fn Handler) { - srv.mu.Lock() - defer srv.mu.Unlock() - - srv.Handler = fn -} - -// Close immediately closes all active listeners and all active -// connections. -// -// Close returns any error returned from closing the Server's -// underlying Listener(s). -func (srv *Server) Close() error { - srv.mu.Lock() - defer srv.mu.Unlock() - - srv.closeDoneChanLocked() - err := srv.closeListenersLocked() - for c := range srv.conns { - c.Close() - delete(srv.conns, c) - } - return err -} - -// Shutdown gracefully shuts down the server without interrupting any -// active connections. Shutdown works by first closing all open -// listeners, and then waiting indefinitely for connections to close. -// If the provided context expires before the shutdown is complete, -// then the context's error is returned. -func (srv *Server) Shutdown(ctx context.Context) error { - srv.mu.Lock() - lnerr := srv.closeListenersLocked() - srv.closeDoneChanLocked() - srv.mu.Unlock() - - finished := make(chan struct{}, 1) - go func() { - srv.listenerWg.Wait() - srv.connWg.Wait() - finished <- struct{}{} - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case <-finished: - return lnerr - } -} - -// Serve accepts incoming connections on the Listener l, creating a new -// connection goroutine for each. The connection goroutines read requests and then -// calls srv.Handler to handle sessions. -// -// Serve always returns a non-nil error. -func (srv *Server) Serve(l net.Listener) error { - srv.ensureHandlers() - defer l.Close() - if err := srv.ensureHostSigner(); err != nil { - return err - } - if srv.Handler == nil { - srv.Handler = DefaultHandler - } - var tempDelay time.Duration - - srv.trackListener(l, true) - defer srv.trackListener(l, false) - for { - conn, e := l.Accept() - if e != nil { - select { - case <-srv.getDoneChan(): - return ErrServerClosed - default: - } - if ne, ok := e.(net.Error); ok && ne.Temporary() { - if tempDelay == 0 { - tempDelay = 5 * time.Millisecond - } else { - tempDelay *= 2 - } - if max := 1 * time.Second; tempDelay > max { - tempDelay = max - } - time.Sleep(tempDelay) - continue - } - return e - } - go srv.HandleConn(conn) - } -} - -func (srv *Server) HandleConn(newConn net.Conn) { - ctx, cancel := newContext(srv) - if srv.ConnCallback != nil { - cbConn := srv.ConnCallback(ctx, newConn) - if cbConn == nil { - newConn.Close() - return - } - newConn = cbConn - } - conn := &serverConn{ - Conn: newConn, - idleTimeout: srv.IdleTimeout, - closeCanceler: cancel, - } - if srv.MaxTimeout > 0 { - conn.maxDeadline = time.Now().Add(srv.MaxTimeout) - } - defer conn.Close() - sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx)) - if err != nil { - if srv.ConnectionFailedCallback != nil { - srv.ConnectionFailedCallback(conn, err) - } - return - } - - srv.trackConn(sshConn, true) - defer srv.trackConn(sshConn, false) - - ctx.SetValue(ContextKeyConn, sshConn) - applyConnMetadata(ctx, sshConn) - //go gossh.DiscardRequests(reqs) - go srv.handleRequests(ctx, reqs) - for ch := range chans { - handler := srv.ChannelHandlers[ch.ChannelType()] - if handler == nil { - handler = srv.ChannelHandlers["default"] - } - if handler == nil { - ch.Reject(gossh.UnknownChannelType, "unsupported channel type") - continue - } - go handler(srv, sshConn, ch, ctx) - } -} - -func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { - for req := range in { - handler := srv.RequestHandlers[req.Type] - if handler == nil { - handler = srv.RequestHandlers["default"] - } - if handler == nil { - req.Reply(false, nil) - continue - } - /*reqCtx, cancel := context.WithCancel(ctx) - defer cancel() */ - ret, payload := handler(ctx, srv, req) - req.Reply(ret, payload) - } -} - -// ListenAndServe listens on the TCP network address srv.Addr and then calls -// Serve to handle incoming connections. If srv.Addr is blank, ":22" is used. -// ListenAndServe always returns a non-nil error. -func (srv *Server) ListenAndServe() error { - addr := srv.Addr - if addr == "" { - addr = ":22" - } - ln, err := net.Listen("tcp", addr) - if err != nil { - return err - } - return srv.Serve(ln) -} - -// AddHostKey adds a private key as a host key. If an existing host key exists -// with the same algorithm, it is overwritten. Each server config must have at -// least one host key. -func (srv *Server) AddHostKey(key Signer) { - srv.mu.Lock() - defer srv.mu.Unlock() - - // these are later added via AddHostKey on ServerConfig, which performs the - // check for one of every algorithm. - - // This check is based on the AddHostKey method from the x/crypto/ssh - // library. This allows us to only keep one active key for each type on a - // server at once. So, if you're dynamically updating keys at runtime, this - // list will not keep growing. - for i, k := range srv.HostSigners { - if k.PublicKey().Type() == key.PublicKey().Type() { - srv.HostSigners[i] = key - return - } - } - - srv.HostSigners = append(srv.HostSigners, key) -} - -// SetOption runs a functional option against the server. -func (srv *Server) SetOption(option Option) error { - // NOTE: there is a potential race here for any option that doesn't call an - // internal method. We can't actually lock here because if something calls - // (as an example) AddHostKey, it will deadlock. - - //srv.mu.Lock() - //defer srv.mu.Unlock() - - return option(srv) -} - -func (srv *Server) getDoneChan() <-chan struct{} { - srv.mu.Lock() - defer srv.mu.Unlock() - - return srv.getDoneChanLocked() -} - -func (srv *Server) getDoneChanLocked() chan struct{} { - if srv.doneChan == nil { - srv.doneChan = make(chan struct{}) - } - return srv.doneChan -} - -func (srv *Server) closeDoneChanLocked() { - ch := srv.getDoneChanLocked() - select { - case <-ch: - // Already closed. Don't close again. - default: - // Safe to close here. We're the only closer, guarded - // by srv.mu. - close(ch) - } -} - -func (srv *Server) closeListenersLocked() error { - var err error - for ln := range srv.listeners { - if cerr := ln.Close(); cerr != nil && err == nil { - err = cerr - } - delete(srv.listeners, ln) - } - return err -} - -func (srv *Server) trackListener(ln net.Listener, add bool) { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.listeners == nil { - srv.listeners = make(map[net.Listener]struct{}) - } - if add { - // If the *Server is being reused after a previous - // Close or Shutdown, reset its doneChan: - if len(srv.listeners) == 0 && len(srv.conns) == 0 { - srv.doneChan = nil - } - srv.listeners[ln] = struct{}{} - srv.listenerWg.Add(1) - } else { - delete(srv.listeners, ln) - srv.listenerWg.Done() - } -} - -func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.conns == nil { - srv.conns = make(map[*gossh.ServerConn]struct{}) - } - if add { - srv.conns[c] = struct{}{} - srv.connWg.Add(1) - } else { - delete(srv.conns, c) - srv.connWg.Done() - } -} diff --git a/tempfork/gliderlabs/ssh/server_test.go b/tempfork/gliderlabs/ssh/server_test.go deleted file mode 100644 index 177c071170c4e..0000000000000 --- a/tempfork/gliderlabs/ssh/server_test.go +++ /dev/null @@ -1,128 +0,0 @@ -//go:build glidertests - -package ssh - -import ( - "bytes" - "context" - "io" - "testing" - "time" -) - -func TestAddHostKey(t *testing.T) { - s := Server{} - signer, err := generateSigner() - if err != nil { - t.Fatal(err) - } - s.AddHostKey(signer) - if len(s.HostSigners) != 1 { - t.Fatal("Key was not properly added") - } - signer, err = generateSigner() - if err != nil { - t.Fatal(err) - } - s.AddHostKey(signer) - if len(s.HostSigners) != 1 { - t.Fatal("Key was not properly replaced") - } -} - -func TestServerShutdown(t *testing.T) { - l := newLocalListener() - testBytes := []byte("Hello world\n") - s := &Server{ - Handler: func(s Session) { - s.Write(testBytes) - time.Sleep(50 * time.Millisecond) - }, - } - go func() { - err := s.Serve(l) - if err != nil && err != ErrServerClosed { - t.Fatal(err) - } - }() - sessDone := make(chan struct{}) - sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) - go func() { - defer cleanup() - defer close(sessDone) - var stdout bytes.Buffer - sess.Stdout = &stdout - if err := sess.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes()) - } - }() - - srvDone := make(chan struct{}) - go func() { - defer close(srvDone) - err := s.Shutdown(context.Background()) - if err != nil { - t.Fatal(err) - } - }() - - timeout := time.After(2 * time.Second) - select { - case <-timeout: - t.Fatal("timeout") - return - case <-srvDone: - // TODO: add timeout for sessDone - <-sessDone - return - } -} - -func TestServerClose(t *testing.T) { - l := newLocalListener() - s := &Server{ - Handler: func(s Session) { - time.Sleep(5 * time.Second) - }, - } - go func() { - err := s.Serve(l) - if err != nil && err != ErrServerClosed { - t.Fatal(err) - } - }() - - clientDoneChan := make(chan struct{}) - closeDoneChan := make(chan struct{}) - - sess, _, cleanup := newClientSession(t, l.Addr().String(), nil) - go func() { - defer cleanup() - defer close(clientDoneChan) - <-closeDoneChan - if err := sess.Run(""); err != nil && err != io.EOF { - t.Fatal(err) - } - }() - - go func() { - err := s.Close() - if err != nil { - t.Fatal(err) - } - close(closeDoneChan) - }() - - timeout := time.After(100 * time.Millisecond) - select { - case <-timeout: - t.Error("timeout") - return - case <-s.getDoneChan(): - <-clientDoneChan - return - } -} diff --git a/tempfork/gliderlabs/ssh/session.go b/tempfork/gliderlabs/ssh/session.go deleted file mode 100644 index a7a9a3eebd96f..0000000000000 --- a/tempfork/gliderlabs/ssh/session.go +++ /dev/null @@ -1,386 +0,0 @@ -package ssh - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "sync" - - "github.com/anmitsu/go-shlex" - gossh "golang.org/x/crypto/ssh" -) - -// Session provides access to information about an SSH session and methods -// to read and write to the SSH channel with an embedded Channel interface from -// crypto/ssh. -// -// When Command() returns an empty slice, the user requested a shell. Otherwise -// the user is performing an exec with those command arguments. -// -// TODO: Signals -type Session interface { - gossh.Channel - - // User returns the username used when establishing the SSH connection. - User() string - - // RemoteAddr returns the net.Addr of the client side of the connection. - RemoteAddr() net.Addr - - // LocalAddr returns the net.Addr of the server side of the connection. - LocalAddr() net.Addr - - // Environ returns a copy of strings representing the environment set by the - // user for this session, in the form "key=value". - Environ() []string - - // Exit sends an exit status and then closes the session. - Exit(code int) error - - // Command returns a shell parsed slice of arguments that were provided by the - // user. Shell parsing splits the command string according to POSIX shell rules, - // which considers quoting not just whitespace. - Command() []string - - // RawCommand returns the exact command that was provided by the user. - RawCommand() string - - // Subsystem returns the subsystem requested by the user. - Subsystem() string - - // PublicKey returns the PublicKey used to authenticate. If a public key was not - // used it will return nil. - PublicKey() PublicKey - - // Context returns the connection's context. The returned context is always - // non-nil and holds the same data as the Context passed into auth - // handlers and callbacks. - // - // The context is canceled when the client's connection closes or I/O - // operation fails. - Context() context.Context - - // Permissions returns a copy of the Permissions object that was available for - // setup in the auth handlers via the Context. - Permissions() Permissions - - // Pty returns PTY information, a channel of window size changes, and a boolean - // of whether or not a PTY was accepted for this session. - Pty() (Pty, <-chan Window, bool) - - // Signals registers a channel to receive signals sent from the client. The - // channel must handle signal sends or it will block the SSH request loop. - // Registering nil will unregister the channel from signal sends. During the - // time no channel is registered signals are buffered up to a reasonable amount. - // If there are buffered signals when a channel is registered, they will be - // sent in order on the channel immediately after registering. - Signals(c chan<- Signal) - - // Break regisers a channel to receive notifications of break requests sent - // from the client. The channel must handle break requests, or it will block - // the request handling loop. Registering nil will unregister the channel. - // During the time that no channel is registered, breaks are ignored. - Break(c chan<- bool) - - // DisablePTYEmulation disables the session's default minimal PTY emulation. - // If you're setting the pty's termios settings from the Pty request, use - // this method to avoid corruption. - // Currently (2022-03-12) the only emulation implemented is NL-to-CRNL translation (`\n`=>`\r\n`). - // A call of DisablePTYEmulation must precede any call to Write. - DisablePTYEmulation() -} - -// maxSigBufSize is how many signals will be buffered -// when there is no signal channel specified -const maxSigBufSize = 128 - -func DefaultSessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { - ch, reqs, err := newChan.Accept() - if err != nil { - // TODO: trigger event callback - return - } - sess := &session{ - Channel: ch, - conn: conn, - handler: srv.Handler, - ptyCb: srv.PtyCallback, - sessReqCb: srv.SessionRequestCallback, - subsystemHandlers: srv.SubsystemHandlers, - ctx: ctx, - } - sess.handleRequests(reqs) -} - -type session struct { - sync.Mutex - gossh.Channel - conn *gossh.ServerConn - handler Handler - subsystemHandlers map[string]SubsystemHandler - handled bool - exited bool - pty *Pty - winch chan Window - env []string - ptyCb PtyCallback - sessReqCb SessionRequestCallback - rawCmd string - subsystem string - ctx Context - sigCh chan<- Signal - sigBuf []Signal - breakCh chan<- bool - disablePtyEmulation bool -} - -func (sess *session) DisablePTYEmulation() { - sess.disablePtyEmulation = true -} - -func (sess *session) Write(p []byte) (n int, err error) { - if sess.pty != nil && !sess.disablePtyEmulation { - m := len(p) - // normalize \n to \r\n when pty is accepted. - // this is a hardcoded shortcut since we don't support terminal modes. - p = bytes.Replace(p, []byte{'\n'}, []byte{'\r', '\n'}, -1) - p = bytes.Replace(p, []byte{'\r', '\r', '\n'}, []byte{'\r', '\n'}, -1) - n, err = sess.Channel.Write(p) - if n > m { - n = m - } - return - } - return sess.Channel.Write(p) -} - -func (sess *session) PublicKey() PublicKey { - sessionkey := sess.ctx.Value(ContextKeyPublicKey) - if sessionkey == nil { - return nil - } - return sessionkey.(PublicKey) -} - -func (sess *session) Permissions() Permissions { - // use context permissions because its properly - // wrapped and easier to dereference - perms := sess.ctx.Value(ContextKeyPermissions).(*Permissions) - return *perms -} - -func (sess *session) Context() context.Context { - return sess.ctx -} - -func (sess *session) Exit(code int) error { - sess.Lock() - defer sess.Unlock() - if sess.exited { - return errors.New("Session.Exit called multiple times") - } - sess.exited = true - - status := struct{ Status uint32 }{uint32(code)} - _, err := sess.SendRequest("exit-status", false, gossh.Marshal(&status)) - if err != nil { - return err - } - return sess.Close() -} - -func (sess *session) User() string { - return sess.conn.User() -} - -func (sess *session) RemoteAddr() net.Addr { - return sess.conn.RemoteAddr() -} - -func (sess *session) LocalAddr() net.Addr { - return sess.conn.LocalAddr() -} - -func (sess *session) Environ() []string { - return append([]string(nil), sess.env...) -} - -func (sess *session) RawCommand() string { - return sess.rawCmd -} - -func (sess *session) Command() []string { - cmd, _ := shlex.Split(sess.rawCmd, true) - return append([]string(nil), cmd...) -} - -func (sess *session) Subsystem() string { - return sess.subsystem -} - -func (sess *session) Pty() (Pty, <-chan Window, bool) { - if sess.pty != nil { - return *sess.pty, sess.winch, true - } - return Pty{}, sess.winch, false -} - -func (sess *session) Signals(c chan<- Signal) { - sess.Lock() - defer sess.Unlock() - sess.sigCh = c - if len(sess.sigBuf) > 0 { - go func() { - for _, sig := range sess.sigBuf { - sess.sigCh <- sig - } - }() - } -} - -func (sess *session) Break(c chan<- bool) { - sess.Lock() - defer sess.Unlock() - sess.breakCh = c -} - -func (sess *session) handleRequests(reqs <-chan *gossh.Request) { - for req := range reqs { - switch req.Type { - case "shell", "exec": - if sess.handled { - req.Reply(false, nil) - continue - } - - var payload = struct{ Value string }{} - gossh.Unmarshal(req.Payload, &payload) - sess.rawCmd = payload.Value - - // If there's a session policy callback, we need to confirm before - // accepting the session. - if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { - sess.rawCmd = "" - req.Reply(false, nil) - continue - } - - sess.handled = true - req.Reply(true, nil) - - go func() { - sess.handler(sess) - sess.Exit(0) - }() - case "subsystem": - if sess.handled { - req.Reply(false, nil) - continue - } - - var payload = struct{ Value string }{} - gossh.Unmarshal(req.Payload, &payload) - sess.subsystem = payload.Value - - // If there's a session policy callback, we need to confirm before - // accepting the session. - if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) { - sess.rawCmd = "" - req.Reply(false, nil) - continue - } - - handler := sess.subsystemHandlers[payload.Value] - if handler == nil { - handler = sess.subsystemHandlers["default"] - } - if handler == nil { - req.Reply(false, nil) - continue - } - - sess.handled = true - req.Reply(true, nil) - - go func() { - handler(sess) - sess.Exit(0) - }() - case "env": - if sess.handled { - req.Reply(false, nil) - continue - } - var kv struct{ Key, Value string } - gossh.Unmarshal(req.Payload, &kv) - sess.env = append(sess.env, fmt.Sprintf("%s=%s", kv.Key, kv.Value)) - req.Reply(true, nil) - case "signal": - var payload struct{ Signal string } - gossh.Unmarshal(req.Payload, &payload) - sess.Lock() - if sess.sigCh != nil { - sess.sigCh <- Signal(payload.Signal) - } else { - if len(sess.sigBuf) < maxSigBufSize { - sess.sigBuf = append(sess.sigBuf, Signal(payload.Signal)) - } - } - sess.Unlock() - case "pty-req": - if sess.handled || sess.pty != nil { - req.Reply(false, nil) - continue - } - ptyReq, ok := parsePtyRequest(req.Payload) - if !ok { - req.Reply(false, nil) - continue - } - if sess.ptyCb != nil { - ok := sess.ptyCb(sess.ctx, ptyReq) - if !ok { - req.Reply(false, nil) - continue - } - } - sess.pty = &ptyReq - sess.winch = make(chan Window, 1) - sess.winch <- ptyReq.Window - defer func() { - // when reqs is closed - close(sess.winch) - }() - req.Reply(ok, nil) - case "window-change": - if sess.pty == nil { - req.Reply(false, nil) - continue - } - win, _, ok := parseWindow(req.Payload) - if ok { - sess.pty.Window = win - sess.winch <- win - } - req.Reply(ok, nil) - case agentRequestType: - // TODO: option/callback to allow agent forwarding - SetAgentRequested(sess.ctx) - req.Reply(true, nil) - case "break": - ok := false - sess.Lock() - if sess.breakCh != nil { - sess.breakCh <- true - ok = true - } - req.Reply(ok, nil) - sess.Unlock() - default: - // TODO: debug log - req.Reply(false, nil) - } - } -} diff --git a/tempfork/gliderlabs/ssh/session_test.go b/tempfork/gliderlabs/ssh/session_test.go deleted file mode 100644 index fe61a9d96be9b..0000000000000 --- a/tempfork/gliderlabs/ssh/session_test.go +++ /dev/null @@ -1,440 +0,0 @@ -//go:build glidertests - -package ssh - -import ( - "bytes" - "fmt" - "io" - "net" - "testing" - - gossh "golang.org/x/crypto/ssh" -) - -func (srv *Server) serveOnce(l net.Listener) error { - srv.ensureHandlers() - if err := srv.ensureHostSigner(); err != nil { - return err - } - conn, e := l.Accept() - if e != nil { - return e - } - srv.ChannelHandlers = map[string]ChannelHandler{ - "session": DefaultSessionHandler, - "direct-tcpip": DirectTCPIPHandler, - } - srv.HandleConn(conn) - return nil -} - -func newLocalListener() net.Listener { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { - panic(fmt.Sprintf("failed to listen on a port: %v", err)) - } - } - return l -} - -func newClientSession(t *testing.T, addr string, config *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { - if config == nil { - config = &gossh.ClientConfig{ - User: "testuser", - Auth: []gossh.AuthMethod{ - gossh.Password("testpass"), - }, - } - } - if config.HostKeyCallback == nil { - config.HostKeyCallback = gossh.InsecureIgnoreHostKey() - } - client, err := gossh.Dial("tcp", addr, config) - if err != nil { - t.Fatal(err) - } - session, err := client.NewSession() - if err != nil { - t.Fatal(err) - } - return session, client, func() { - session.Close() - client.Close() - } -} - -func newTestSession(t *testing.T, srv *Server, cfg *gossh.ClientConfig) (*gossh.Session, *gossh.Client, func()) { - l := newLocalListener() - go srv.serveOnce(l) - return newClientSession(t, l.Addr().String(), cfg) -} - -func TestStdout(t *testing.T) { - t.Parallel() - testBytes := []byte("Hello world\n") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Write(testBytes) - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("stdout = %#v; want %#v", stdout.Bytes(), testBytes) - } -} - -func TestStderr(t *testing.T) { - t.Parallel() - testBytes := []byte("Hello world\n") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Stderr().Write(testBytes) - }, - }, nil) - defer cleanup() - var stderr bytes.Buffer - session.Stderr = &stderr - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stderr.Bytes(), testBytes) { - t.Fatalf("stderr = %#v; want %#v", stderr.Bytes(), testBytes) - } -} - -func TestStdin(t *testing.T) { - t.Parallel() - testBytes := []byte("Hello world\n") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - io.Copy(s, s) // stdin back into stdout - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - session.Stdin = bytes.NewBuffer(testBytes) - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("stdout = %#v; want %#v given stdin = %#v", stdout.Bytes(), testBytes, testBytes) - } -} - -func TestUser(t *testing.T) { - t.Parallel() - testUser := []byte("progrium") - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - io.WriteString(s, s.User()) - }, - }, &gossh.ClientConfig{ - User: string(testUser), - }) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - if err := session.Run(""); err != nil { - t.Fatal(err) - } - if !bytes.Equal(stdout.Bytes(), testUser) { - t.Fatalf("stdout = %#v; want %#v given user = %#v", stdout.Bytes(), testUser, string(testUser)) - } -} - -func TestDefaultExitStatusZero(t *testing.T) { - t.Parallel() - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - // noop - }, - }, nil) - defer cleanup() - err := session.Run("") - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -func TestExplicitExitStatusZero(t *testing.T) { - t.Parallel() - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Exit(0) - }, - }, nil) - defer cleanup() - err := session.Run("") - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -func TestExitStatusNonZero(t *testing.T) { - t.Parallel() - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Exit(1) - }, - }, nil) - defer cleanup() - err := session.Run("") - e, ok := err.(*gossh.ExitError) - if !ok { - t.Fatalf("expected ExitError but got %T", err) - } - if e.ExitStatus() != 1 { - t.Fatalf("exit-status = %#v; want %#v", e.ExitStatus(), 1) - } -} - -func TestPty(t *testing.T) { - t.Parallel() - term := "xterm" - winWidth := 40 - winHeight := 80 - done := make(chan bool) - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - ptyReq, _, isPty := s.Pty() - if !isPty { - t.Fatalf("expected pty but none requested") - } - if ptyReq.Term != term { - t.Fatalf("expected term %#v but got %#v", term, ptyReq.Term) - } - if ptyReq.Window.Width != winWidth { - t.Fatalf("expected window width %#v but got %#v", winWidth, ptyReq.Window.Width) - } - if ptyReq.Window.Height != winHeight { - t.Fatalf("expected window height %#v but got %#v", winHeight, ptyReq.Window.Height) - } - close(done) - }, - }, nil) - defer cleanup() - if err := session.RequestPty(term, winHeight, winWidth, gossh.TerminalModes{}); err != nil { - t.Fatalf("expected nil but got %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("expected nil but got %v", err) - } - <-done -} - -func TestPtyResize(t *testing.T) { - t.Parallel() - winch0 := Window{Width: 40, Height: 80} - winch1 := Window{Width: 80, Height: 160} - winch2 := Window{Width: 20, Height: 40} - winches := make(chan Window) - done := make(chan bool) - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - ptyReq, winCh, isPty := s.Pty() - if !isPty { - t.Fatalf("expected pty but none requested") - } - if ptyReq.Window != winch0 { - t.Fatalf("expected window %#v but got %#v", winch0, ptyReq.Window) - } - for win := range winCh { - winches <- win - } - close(done) - }, - }, nil) - defer cleanup() - // winch0 - if err := session.RequestPty("xterm", winch0.Height, winch0.Width, gossh.TerminalModes{}); err != nil { - t.Fatalf("expected nil but got %v", err) - } - if err := session.Shell(); err != nil { - t.Fatalf("expected nil but got %v", err) - } - gotWinch := <-winches - if gotWinch != winch0 { - t.Fatalf("expected window %#v but got %#v", winch0, gotWinch) - } - // winch1 - winchMsg := struct{ w, h uint32 }{uint32(winch1.Width), uint32(winch1.Height)} - ok, err := session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) - if err == nil && !ok { - t.Fatalf("unexpected error or bad reply on send request") - } - gotWinch = <-winches - if gotWinch != winch1 { - t.Fatalf("expected window %#v but got %#v", winch1, gotWinch) - } - // winch2 - winchMsg = struct{ w, h uint32 }{uint32(winch2.Width), uint32(winch2.Height)} - ok, err = session.SendRequest("window-change", true, gossh.Marshal(&winchMsg)) - if err == nil && !ok { - t.Fatalf("unexpected error or bad reply on send request") - } - gotWinch = <-winches - if gotWinch != winch2 { - t.Fatalf("expected window %#v but got %#v", winch2, gotWinch) - } - session.Close() - <-done -} - -func TestSignals(t *testing.T) { - t.Parallel() - - // errChan lets us get errors back from the session - errChan := make(chan error, 5) - - // doneChan lets us specify that we should exit. - doneChan := make(chan interface{}) - - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - // We need to use a buffered channel here, otherwise it's possible for the - // second call to Signal to get discarded. - signals := make(chan Signal, 2) - s.Signals(signals) - - select { - case sig := <-signals: - if sig != SIGINT { - errChan <- fmt.Errorf("expected signal %v but got %v", SIGINT, sig) - return - } - case <-doneChan: - errChan <- fmt.Errorf("Unexpected done") - return - } - - select { - case sig := <-signals: - if sig != SIGKILL { - errChan <- fmt.Errorf("expected signal %v but got %v", SIGKILL, sig) - return - } - case <-doneChan: - errChan <- fmt.Errorf("Unexpected done") - return - } - }, - }, nil) - defer cleanup() - - go func() { - session.Signal(gossh.SIGINT) - session.Signal(gossh.SIGKILL) - }() - - go func() { - errChan <- session.Run("") - }() - - err := <-errChan - close(doneChan) - - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} - -func TestBreakWithChanRegistered(t *testing.T) { - t.Parallel() - - // errChan lets us get errors back from the session - errChan := make(chan error, 5) - - // doneChan lets us specify that we should exit. - doneChan := make(chan interface{}) - - breakChan := make(chan bool) - - readyToReceiveBreak := make(chan bool) - - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - s.Break(breakChan) // register a break channel with the session - readyToReceiveBreak <- true - - select { - case <-breakChan: - io.WriteString(s, "break") - case <-doneChan: - errChan <- fmt.Errorf("Unexpected done") - return - } - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - go func() { - errChan <- session.Run("") - }() - - <-readyToReceiveBreak - ok, err := session.SendRequest("break", true, nil) - if err != nil { - t.Fatalf("expected nil but got %v", err) - } - if ok != true { - t.Fatalf("expected true but got %v", ok) - } - - err = <-errChan - close(doneChan) - - if err != nil { - t.Fatalf("expected nil but got %v", err) - } - if !bytes.Equal(stdout.Bytes(), []byte("break")) { - t.Fatalf("stdout = %#v, expected 'break'", stdout.Bytes()) - } -} - -func TestBreakWithoutChanRegistered(t *testing.T) { - t.Parallel() - - // errChan lets us get errors back from the session - errChan := make(chan error, 5) - - // doneChan lets us specify that we should exit. - doneChan := make(chan interface{}) - - waitUntilAfterBreakSent := make(chan bool) - - session, _, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) { - <-waitUntilAfterBreakSent - }, - }, nil) - defer cleanup() - var stdout bytes.Buffer - session.Stdout = &stdout - go func() { - errChan <- session.Run("") - }() - - ok, err := session.SendRequest("break", true, nil) - if err != nil { - t.Fatalf("expected nil but got %v", err) - } - if ok != false { - t.Fatalf("expected false but got %v", ok) - } - waitUntilAfterBreakSent <- true - - err = <-errChan - close(doneChan) - if err != nil { - t.Fatalf("expected nil but got %v", err) - } -} diff --git a/tempfork/gliderlabs/ssh/ssh.go b/tempfork/gliderlabs/ssh/ssh.go deleted file mode 100644 index 54bd31ec2fcb4..0000000000000 --- a/tempfork/gliderlabs/ssh/ssh.go +++ /dev/null @@ -1,156 +0,0 @@ -package ssh - -import ( - "crypto/subtle" - "net" - - gossh "golang.org/x/crypto/ssh" -) - -type Signal string - -// POSIX signals as listed in RFC 4254 Section 6.10. -const ( - SIGABRT Signal = "ABRT" - SIGALRM Signal = "ALRM" - SIGFPE Signal = "FPE" - SIGHUP Signal = "HUP" - SIGILL Signal = "ILL" - SIGINT Signal = "INT" - SIGKILL Signal = "KILL" - SIGPIPE Signal = "PIPE" - SIGQUIT Signal = "QUIT" - SIGSEGV Signal = "SEGV" - SIGTERM Signal = "TERM" - SIGUSR1 Signal = "USR1" - SIGUSR2 Signal = "USR2" -) - -// DefaultHandler is the default Handler used by Serve. -var DefaultHandler Handler - -// Option is a functional option handler for Server. -type Option func(*Server) error - -// Handler is a callback for handling established SSH sessions. -type Handler func(Session) - -// PublicKeyHandler is a callback for performing public key authentication. -type PublicKeyHandler func(ctx Context, key PublicKey) error - -type NoClientAuthHandler func(ctx Context) error - -type BannerHandler func(ctx Context) string - -// PasswordHandler is a callback for performing password authentication. -type PasswordHandler func(ctx Context, password string) bool - -// KeyboardInteractiveHandler is a callback for performing keyboard-interactive authentication. -type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInteractiveChallenge) bool - -// PtyCallback is a hook for allowing PTY sessions. -type PtyCallback func(ctx Context, pty Pty) bool - -// SessionRequestCallback is a callback for allowing or denying SSH sessions. -type SessionRequestCallback func(sess Session, requestType string) bool - -// ConnCallback is a hook for new connections before handling. -// It allows wrapping for timeouts and limiting by returning -// the net.Conn that will be used as the underlying connection. -type ConnCallback func(ctx Context, conn net.Conn) net.Conn - -// LocalPortForwardingCallback is a hook for allowing port forwarding -type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool - -// ReversePortForwardingCallback is a hook for allowing reverse port forwarding -type ReversePortForwardingCallback func(ctx Context, bindHost string, bindPort uint32) bool - -// ServerConfigCallback is a hook for creating custom default server configs -type ServerConfigCallback func(ctx Context) *gossh.ServerConfig - -// ConnectionFailedCallback is a hook for reporting failed connections -// Please note: the net.Conn is likely to be closed at this point -type ConnectionFailedCallback func(conn net.Conn, err error) - -// Window represents the size of a PTY window. -// -// See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 -// -// Zero dimension parameters MUST be ignored. The character/row dimensions -// override the pixel dimensions (when nonzero). Pixel dimensions refer -// to the drawable area of the window. -type Window struct { - // Width is the number of columns. - // It overrides WidthPixels. - Width int - // Height is the number of rows. - // It overrides HeightPixels. - Height int - - // WidthPixels is the drawable width of the window, in pixels. - WidthPixels int - // HeightPixels is the drawable height of the window, in pixels. - HeightPixels int -} - -// Pty represents a PTY request and configuration. -type Pty struct { - // Term is the TERM environment variable value. - Term string - - // Window is the Window sent as part of the pty-req. - Window Window - - // Modes represent a mapping of Terminal Mode opcode to value as it was - // requested by the client as part of the pty-req. These are outlined as - // part of https://datatracker.ietf.org/doc/html/rfc4254#section-8. - // - // The opcodes are defined as constants in golang.org/x/crypto/ssh (VINTR,VQUIT,etc.). - // Boolean opcodes have values 0 or 1. - Modes gossh.TerminalModes -} - -// Serve accepts incoming SSH connections on the listener l, creating a new -// connection goroutine for each. The connection goroutines read requests and -// then calls handler to handle sessions. Handler is typically nil, in which -// case the DefaultHandler is used. -func Serve(l net.Listener, handler Handler, options ...Option) error { - srv := &Server{Handler: handler} - for _, option := range options { - if err := srv.SetOption(option); err != nil { - return err - } - } - return srv.Serve(l) -} - -// ListenAndServe listens on the TCP network address addr and then calls Serve -// with handler to handle sessions on incoming connections. Handler is typically -// nil, in which case the DefaultHandler is used. -func ListenAndServe(addr string, handler Handler, options ...Option) error { - srv := &Server{Addr: addr, Handler: handler} - for _, option := range options { - if err := srv.SetOption(option); err != nil { - return err - } - } - return srv.ListenAndServe() -} - -// Handle registers the handler as the DefaultHandler. -func Handle(handler Handler) { - DefaultHandler = handler -} - -// KeysEqual is constant time compare of the keys to avoid timing attacks. -func KeysEqual(ak, bk PublicKey) bool { - - //avoid panic if one of the keys is nil, return false instead - if ak == nil || bk == nil { - return false - } - - a := ak.Marshal() - b := bk.Marshal() - return (len(a) == len(b) && subtle.ConstantTimeCompare(a, b) == 1) -} diff --git a/tempfork/gliderlabs/ssh/ssh_test.go b/tempfork/gliderlabs/ssh/ssh_test.go deleted file mode 100644 index aa301b0489f21..0000000000000 --- a/tempfork/gliderlabs/ssh/ssh_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package ssh - -import ( - "testing" -) - -func TestKeysEqual(t *testing.T) { - defer func() { - if r := recover(); r != nil { - t.Errorf("The code did panic") - } - }() - - if KeysEqual(nil, nil) { - t.Error("two nil keys should not return true") - } -} diff --git a/tempfork/gliderlabs/ssh/tcpip.go b/tempfork/gliderlabs/ssh/tcpip.go deleted file mode 100644 index 335fda65754ea..0000000000000 --- a/tempfork/gliderlabs/ssh/tcpip.go +++ /dev/null @@ -1,193 +0,0 @@ -package ssh - -import ( - "io" - "log" - "net" - "strconv" - "sync" - - gossh "golang.org/x/crypto/ssh" -) - -const ( - forwardedTCPChannelType = "forwarded-tcpip" -) - -// direct-tcpip data struct as specified in RFC4254, Section 7.2 -type localForwardChannelData struct { - DestAddr string - DestPort uint32 - - OriginAddr string - OriginPort uint32 -} - -// DirectTCPIPHandler can be enabled by adding it to the server's -// ChannelHandlers under direct-tcpip. -func DirectTCPIPHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) { - d := localForwardChannelData{} - if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { - newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) - return - } - - if srv.LocalPortForwardingCallback == nil || !srv.LocalPortForwardingCallback(ctx, d.DestAddr, d.DestPort) { - newChan.Reject(gossh.Prohibited, "port forwarding is disabled") - return - } - - dest := net.JoinHostPort(d.DestAddr, strconv.FormatInt(int64(d.DestPort), 10)) - - var dialer net.Dialer - dconn, err := dialer.DialContext(ctx, "tcp", dest) - if err != nil { - newChan.Reject(gossh.ConnectionFailed, err.Error()) - return - } - - ch, reqs, err := newChan.Accept() - if err != nil { - dconn.Close() - return - } - go gossh.DiscardRequests(reqs) - - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(ch, dconn) - }() - go func() { - defer ch.Close() - defer dconn.Close() - io.Copy(dconn, ch) - }() -} - -type remoteForwardRequest struct { - BindAddr string - BindPort uint32 -} - -type remoteForwardSuccess struct { - BindPort uint32 -} - -type remoteForwardCancelRequest struct { - BindAddr string - BindPort uint32 -} - -type remoteForwardChannelData struct { - DestAddr string - DestPort uint32 - OriginAddr string - OriginPort uint32 -} - -// ForwardedTCPHandler can be enabled by creating a ForwardedTCPHandler and -// adding the HandleSSHRequest callback to the server's RequestHandlers under -// tcpip-forward and cancel-tcpip-forward. -type ForwardedTCPHandler struct { - forwards map[string]net.Listener - sync.Mutex -} - -func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *gossh.Request) (bool, []byte) { - h.Lock() - if h.forwards == nil { - h.forwards = make(map[string]net.Listener) - } - h.Unlock() - conn := ctx.Value(ContextKeyConn).(*gossh.ServerConn) - switch req.Type { - case "tcpip-forward": - var reqPayload remoteForwardRequest - if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { - // TODO: log parse failure - return false, []byte{} - } - if srv.ReversePortForwardingCallback == nil || !srv.ReversePortForwardingCallback(ctx, reqPayload.BindAddr, reqPayload.BindPort) { - return false, []byte("port forwarding is disabled") - } - addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) - ln, err := net.Listen("tcp", addr) - if err != nil { - // TODO: log listen failure - return false, []byte{} - } - _, destPortStr, _ := net.SplitHostPort(ln.Addr().String()) - destPort, _ := strconv.Atoi(destPortStr) - h.Lock() - h.forwards[addr] = ln - h.Unlock() - go func() { - <-ctx.Done() - h.Lock() - ln, ok := h.forwards[addr] - h.Unlock() - if ok { - ln.Close() - } - }() - go func() { - for { - c, err := ln.Accept() - if err != nil { - // TODO: log accept failure - break - } - originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String()) - originPort, _ := strconv.Atoi(orignPortStr) - payload := gossh.Marshal(&remoteForwardChannelData{ - DestAddr: reqPayload.BindAddr, - DestPort: uint32(destPort), - OriginAddr: originAddr, - OriginPort: uint32(originPort), - }) - go func() { - ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload) - if err != nil { - // TODO: log failure to open channel - log.Println(err) - c.Close() - return - } - go gossh.DiscardRequests(reqs) - go func() { - defer ch.Close() - defer c.Close() - io.Copy(ch, c) - }() - go func() { - defer ch.Close() - defer c.Close() - io.Copy(c, ch) - }() - }() - } - h.Lock() - delete(h.forwards, addr) - h.Unlock() - }() - return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)}) - - case "cancel-tcpip-forward": - var reqPayload remoteForwardCancelRequest - if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { - // TODO: log parse failure - return false, []byte{} - } - addr := net.JoinHostPort(reqPayload.BindAddr, strconv.Itoa(int(reqPayload.BindPort))) - h.Lock() - ln, ok := h.forwards[addr] - h.Unlock() - if ok { - ln.Close() - } - return true, nil - default: - return false, nil - } -} diff --git a/tempfork/gliderlabs/ssh/tcpip_test.go b/tempfork/gliderlabs/ssh/tcpip_test.go deleted file mode 100644 index b3ba60a9bb6b8..0000000000000 --- a/tempfork/gliderlabs/ssh/tcpip_test.go +++ /dev/null @@ -1,85 +0,0 @@ -//go:build glidertests - -package ssh - -import ( - "bytes" - "io" - "net" - "strconv" - "strings" - "testing" - - gossh "golang.org/x/crypto/ssh" -) - -var sampleServerResponse = []byte("Hello world") - -func sampleSocketServer() net.Listener { - l := newLocalListener() - - go func() { - conn, err := l.Accept() - if err != nil { - return - } - conn.Write(sampleServerResponse) - conn.Close() - }() - - return l -} - -func newTestSessionWithForwarding(t *testing.T, forwardingEnabled bool) (net.Listener, *gossh.Client, func()) { - l := sampleSocketServer() - - _, client, cleanup := newTestSession(t, &Server{ - Handler: func(s Session) {}, - LocalPortForwardingCallback: func(ctx Context, destinationHost string, destinationPort uint32) bool { - addr := net.JoinHostPort(destinationHost, strconv.FormatInt(int64(destinationPort), 10)) - if addr != l.Addr().String() { - panic("unexpected destinationHost: " + addr) - } - return forwardingEnabled - }, - }, nil) - - return l, client, func() { - cleanup() - l.Close() - } -} - -func TestLocalPortForwardingWorks(t *testing.T) { - t.Parallel() - - l, client, cleanup := newTestSessionWithForwarding(t, true) - defer cleanup() - - conn, err := client.Dial("tcp", l.Addr().String()) - if err != nil { - t.Fatalf("Error connecting to %v: %v", l.Addr().String(), err) - } - result, err := io.ReadAll(conn) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(result, sampleServerResponse) { - t.Fatalf("result = %#v; want %#v", result, sampleServerResponse) - } -} - -func TestLocalPortForwardingRespectsCallback(t *testing.T) { - t.Parallel() - - l, client, cleanup := newTestSessionWithForwarding(t, false) - defer cleanup() - - _, err := client.Dial("tcp", l.Addr().String()) - if err == nil { - t.Fatalf("Expected error connecting to %v but it succeeded", l.Addr().String()) - } - if !strings.Contains(err.Error(), "port forwarding is disabled") { - t.Fatalf("Expected permission error but got %#v", err) - } -} diff --git a/tempfork/gliderlabs/ssh/util.go b/tempfork/gliderlabs/ssh/util.go deleted file mode 100644 index 3bee06dcdef39..0000000000000 --- a/tempfork/gliderlabs/ssh/util.go +++ /dev/null @@ -1,157 +0,0 @@ -package ssh - -import ( - "crypto/rand" - "crypto/rsa" - "encoding/binary" - - "golang.org/x/crypto/ssh" -) - -func generateSigner() (ssh.Signer, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, err - } - return ssh.NewSignerFromKey(key) -} - -func parsePtyRequest(payload []byte) (pty Pty, ok bool) { - // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.2 - // 6.2. Requesting a Pseudo-Terminal - // A pseudo-terminal can be allocated for the session by sending the - // following message. - // byte SSH_MSG_CHANNEL_REQUEST - // uint32 recipient channel - // string "pty-req" - // boolean want_reply - // string TERM environment variable value (e.g., vt100) - // uint32 terminal width, characters (e.g., 80) - // uint32 terminal height, rows (e.g., 24) - // uint32 terminal width, pixels (e.g., 640) - // uint32 terminal height, pixels (e.g., 480) - // string encoded terminal modes - - // The payload starts from the TERM variable. - term, rem, ok := parseString(payload) - if !ok { - return - } - win, rem, ok := parseWindow(rem) - if !ok { - return - } - modes, ok := parseTerminalModes(rem) - if !ok { - return - } - pty = Pty{ - Term: term, - Window: win, - Modes: modes, - } - return -} - -func parseTerminalModes(in []byte) (modes ssh.TerminalModes, ok bool) { - // See https://datatracker.ietf.org/doc/html/rfc4254#section-8 - // 8. Encoding of Terminal Modes - // - // All 'encoded terminal modes' (as passed in a pty request) are encoded - // into a byte stream. It is intended that the coding be portable - // across different environments. The stream consists of opcode- - // argument pairs wherein the opcode is a byte value. Opcodes 1 to 159 - // have a single uint32 argument. Opcodes 160 to 255 are not yet - // defined, and cause parsing to stop (they should only be used after - // any other data). The stream is terminated by opcode TTY_OP_END - // (0x00). - // - // The client SHOULD put any modes it knows about in the stream, and the - // server MAY ignore any modes it does not know about. This allows some - // degree of machine-independence, at least between systems that use a - // POSIX-like tty interface. The protocol can support other systems as - // well, but the client may need to fill reasonable values for a number - // of parameters so the server pty gets set to a reasonable mode (the - // server leaves all unspecified mode bits in their default values, and - // only some combinations make sense). - _, rem, ok := parseUint32(in) - if !ok { - return - } - const ttyOpEnd = 0 - for len(rem) > 0 { - if modes == nil { - modes = make(ssh.TerminalModes) - } - code := uint8(rem[0]) - rem = rem[1:] - if code == ttyOpEnd || code > 160 { - break - } - var val uint32 - val, rem, ok = parseUint32(rem) - if !ok { - return - } - modes[code] = val - } - ok = true - return -} - -func parseWindow(s []byte) (win Window, rem []byte, ok bool) { - // See https://datatracker.ietf.org/doc/html/rfc4254#section-6.7 - // 6.7. Window Dimension Change Message - // When the window (terminal) size changes on the client side, it MAY - // send a message to the other side to inform it of the new dimensions. - - // byte SSH_MSG_CHANNEL_REQUEST - // uint32 recipient channel - // string "window-change" - // boolean FALSE - // uint32 terminal width, columns - // uint32 terminal height, rows - // uint32 terminal width, pixels - // uint32 terminal height, pixels - wCols, rem, ok := parseUint32(s) - if !ok { - return - } - hRows, rem, ok := parseUint32(rem) - if !ok { - return - } - wPixels, rem, ok := parseUint32(rem) - if !ok { - return - } - hPixels, rem, ok := parseUint32(rem) - if !ok { - return - } - win = Window{ - Width: int(wCols), - Height: int(hRows), - WidthPixels: int(wPixels), - HeightPixels: int(hPixels), - } - return -} - -func parseString(in []byte) (out string, rem []byte, ok bool) { - length, rem, ok := parseUint32(in) - if uint32(len(rem)) < length || !ok { - ok = false - return - } - out, rem = string(rem[:length]), rem[length:] - ok = true - return -} - -func parseUint32(in []byte) (uint32, []byte, bool) { - if len(in) < 4 { - return 0, nil, false - } - return binary.BigEndian.Uint32(in), in[4:], true -} diff --git a/tempfork/gliderlabs/ssh/wrap.go b/tempfork/gliderlabs/ssh/wrap.go deleted file mode 100644 index d1f2b161e6932..0000000000000 --- a/tempfork/gliderlabs/ssh/wrap.go +++ /dev/null @@ -1,33 +0,0 @@ -package ssh - -import gossh "golang.org/x/crypto/ssh" - -// PublicKey is an abstraction of different types of public keys. -type PublicKey interface { - gossh.PublicKey -} - -// The Permissions type holds fine-grained permissions that are specific to a -// user or a specific authentication method for a user. Permissions, except for -// "source-address", must be enforced in the server application layer, after -// successful authentication. -type Permissions struct { - *gossh.Permissions -} - -// A Signer can create signatures that verify against a public key. -type Signer interface { - gossh.Signer -} - -// ParseAuthorizedKey parses a public key from an authorized_keys file used in -// OpenSSH according to the sshd(8) manual page. -func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { - return gossh.ParseAuthorizedKey(in) -} - -// ParsePublicKey parses an SSH public key formatted for use in -// the SSH wire protocol according to RFC 4253, section 6.6. -func ParsePublicKey(in []byte) (out PublicKey, err error) { - return gossh.ParsePublicKey(in) -} diff --git a/tempfork/pkgdoc/pkgdoc.go b/tempfork/pkgdoc/pkgdoc.go new file mode 100644 index 0000000000000..cab38dd48ec32 --- /dev/null +++ b/tempfork/pkgdoc/pkgdoc.go @@ -0,0 +1,234 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package pkgdoc is a library-ified fork of Go's cmd/doc program +// that only does what we need for misc/genreadme. +package pkgdoc + +import ( + "bytes" + "errors" + "fmt" + "go/ast" + "go/build" + "go/doc" + "go/doc/comment" + "go/parser" + "go/token" + "io" + "io/fs" + "log" + "slices" +) + +const ( + punchedCardWidth = 80 + indent = " " +) + +type Package struct { + writer io.Writer // Destination for output. + name string // Package name, json for encoding/json. + userPath string // String the user used to find this package. + pkg *ast.Package // Parsed package. + file *ast.File // Merged from all files in the package + doc *doc.Package + build *build.Package + fs *token.FileSet // Needed for printing. + buf pkgBuffer +} + +func (pkg *Package) ToText(w io.Writer, text, prefix, codePrefix string) { + d := pkg.doc.Parser().Parse(text) + pr := pkg.doc.Printer() + pr.TextPrefix = prefix + pr.TextCodePrefix = codePrefix + w.Write(pr.Text(d)) +} + +// ToMarkdown parses the godoc comment text and writes a Markdown rendering to w +// suitable for a repository README.md: top-level sections become ## headings +// without per-heading anchor IDs, and [Symbol] doc links resolve to pkg.go.dev, +// including for symbols in the current package (which the default printer would +// otherwise emit as bare #Name fragments with no backing anchor). +func (pkg *Package) ToMarkdown(w io.Writer, text string) { + d := pkg.doc.Parser().Parse(text) + pr := pkg.doc.Printer() + pr.HeadingLevel = 2 + pr.HeadingID = func(*comment.Heading) string { return "" } + pr.DocLinkBaseURL = "https://pkg.go.dev" + pr.DocLinkURL = func(link *comment.DocLink) string { + importPath := link.ImportPath + if importPath == "" { + importPath = pkg.doc.ImportPath + } + name := link.Name + if link.Recv != "" { + name = link.Recv + "." + name + } + return "https://pkg.go.dev/" + importPath + "#" + name + } + w.Write(pr.Markdown(d)) +} + +// pkgBuffer is a wrapper for bytes.Buffer that prints a package clause the +// first time Write is called. +type pkgBuffer struct { + pkg *Package + printed bool // Prevent repeated package clauses. + bytes.Buffer +} + +func (pb *pkgBuffer) Write(p []byte) (int, error) { + pb.packageClause() + return pb.Buffer.Write(p) +} + +func (pb *pkgBuffer) packageClause() { + if !pb.printed { + pb.printed = true + // Only show package clause for commands if requested explicitly. + if pb.pkg.pkg.Name != "main" { + pb.pkg.packageClause() + } + } +} + +type PackageError string // type returned by pkg.Fatalf. + +func (p PackageError) Error() string { + return string(p) +} + +// parsePackage turns the build package we found into a parsed package +// we can then use to generate documentation. +func parsePackage(writer io.Writer, pkg *build.Package, userPath string) *Package { + // include tells parser.ParseDir which files to include. + // That means the file must be in the build package's GoFiles or CgoFiles + // list only (no tag-ignored files, tests, swig or other non-Go files). + include := func(info fs.FileInfo) bool { + return slices.Contains(pkg.GoFiles, info.Name()) || slices.Contains(pkg.CgoFiles, info.Name()) + } + fset := token.NewFileSet() + // Parse declarations (not just imports) so that doc.Package knows the + // package's symbols; the Markdown printer needs this to resolve + // [Symbol] doc links in package comments. + pkgs, err := parser.ParseDir(fset, pkg.Dir, include, parser.ParseComments) + if err != nil { + log.Fatal(err) + } + // Make sure they are all in one package. + if len(pkgs) == 0 { + log.Fatalf("no source-code package in directory %s", pkg.Dir) + } + if len(pkgs) > 1 { + log.Fatalf("multiple packages in directory %s", pkg.Dir) + } + astPkg := pkgs[pkg.Name] + + // TODO: go/doc does not include typed constants in the constants + // list, which is what we want. For instance, time.Sunday is of type + // time.Weekday, so it is defined in the type but not in the + // Consts list for the package. This prevents + // go doc time.Sunday + // from finding the symbol. Work around this for now, but we + // should fix it in go/doc. + // A similar story applies to factory functions. + mode := doc.AllDecls + docPkg := doc.New(astPkg, pkg.ImportPath, mode) + + p := &Package{ + writer: writer, + name: pkg.Name, + userPath: userPath, + pkg: astPkg, + file: ast.MergePackageFiles(astPkg, 0), + doc: docPkg, + build: pkg, + fs: fset, + } + p.buf.pkg = p + return p +} + +func (pkg *Package) Printf(format string, args ...any) { + fmt.Fprintf(&pkg.buf, format, args...) +} + +func (pkg *Package) flush() { + _, err := pkg.writer.Write(pkg.buf.Bytes()) + if err != nil { + log.Fatal(err) + } + pkg.buf.Reset() // Not needed, but it's a flush. +} + +var newlineBytes = []byte("\n\n") // We never ask for more than 2. + +// newlines guarantees there are n newlines at the end of the buffer. +func (pkg *Package) newlines(n int) { + for !bytes.HasSuffix(pkg.buf.Bytes(), newlineBytes[:n]) { + pkg.buf.WriteRune('\n') + } +} + +// packageDoc prints the docs for the package as Markdown. +func (pkg *Package) packageDoc() { + pkg.Printf("") // Trigger the package clause; we know the package exists. + pkg.ToMarkdown(&pkg.buf, pkg.doc.Doc) + pkg.newlines(1) + + pkg.bugs() +} + +// packageClause prints the package clause. +func (pkg *Package) packageClause() { + importPath := pkg.build.ImportComment + if importPath == "" { + importPath = pkg.build.ImportPath + } + + pkg.Printf("package %s // import %q\n\n", pkg.name, importPath) +} + +// bugs prints the BUGS information for the package. +// TODO: Provide access to TODOs and NOTEs as well (very noisy so off by default)? +func (pkg *Package) bugs() { + if pkg.doc.Notes["BUG"] == nil { + return + } + pkg.Printf("\n") + for _, note := range pkg.doc.Notes["BUG"] { + pkg.Printf("%s: %v\n", "BUG", note.Body) + } +} + +// PackageDoc generates Markdown documentation for the package in the given +// directory. importPath is the full Go import path of that package (e.g. +// "tailscale.com/tsnet"); it's used to render [Symbol] doc links to the +// right pkg.go.dev URL. If importPath is empty, build.ImportDir's guess +// is used (typically "." for module-based repos). +func PackageDoc(dir, importPath string) ([]byte, error) { + var buf bytes.Buffer + var writer io.Writer = &buf + + buildPackage, err := build.ImportDir(dir, build.ImportComment) + if err != nil { + var noGoError *build.NoGoError + if errors.As(err, &noGoError) { + return nil, nil + } + return nil, err + } + if importPath != "" { + buildPackage.ImportPath = importPath + } + userPath := dir + + pkg := parsePackage(writer, buildPackage, userPath) + pkg.packageDoc() + pkg.flush() + + return buf.Bytes(), nil +} diff --git a/tka/aum_test.go b/tka/aum_test.go index 4f32e91a1964f..78966c76690ba 100644 --- a/tka/aum_test.go +++ b/tka/aum_test.go @@ -104,7 +104,7 @@ func TestSerialization(t *testing.T) { }, bytes.Repeat([]byte{0}, 32)...), []byte{ - 0x02, // |- major type 0 (int), value 2 (second key, DisablementSecrets) + 0x02, // |- major type 0 (int), value 2 (second key, DisablementValues) 0xf6, // |- major type 7 (val), value null (second value, nil) 0x03, // |- major type 0 (int), value 3 (third key, Keys) 0x81, // |- major type 4 (array), value 1 (one item in array) @@ -182,7 +182,7 @@ func TestDeserializeExistingAUMs(t *testing.T) { Want: AUM{ MessageKind: AUMCheckpoint, State: &State{ - DisablementSecrets: [][]byte{ + DisablementValues: [][]byte{ fromBase64("jSwtotIRlTdbkNPV0bZZifOMIGvi1e1VsJPYu8D0tLo="), fromBase64("EIcFRg4lBkYrtz+t4LnGf/KLY7dg18pPjgY24eYlsdQ="), fromBase64("5VU4oRQiMoq5qK00McfpwtmjcheVammLCRwzdp2Zje8="), diff --git a/tka/builder.go b/tka/builder.go index 1e7b130151876..131d54dde1318 100644 --- a/tka/builder.go +++ b/tka/builder.go @@ -67,11 +67,11 @@ func (b *UpdateBuilder) AddKey(key Key) error { } if _, err := b.state.GetKey(keyID); err == nil { - return fmt.Errorf("cannot add key %v: already exists", key) + return fmt.Errorf("cannot add key tlpub:%x: already exists", key.Public) } if len(b.state.Keys) >= maxKeys { - return fmt.Errorf("cannot add key %v: maximum number of keys reached", key) + return fmt.Errorf("cannot add key tlpub:%x: maximum number of keys reached", key.Public) } return b.mkUpdate(AUM{MessageKind: AUMAddKey, Key: &key}) diff --git a/tka/builder_test.go b/tka/builder_test.go index edca1e95a516e..29ecaf88c0382 100644 --- a/tka/builder_test.go +++ b/tka/builder_test.go @@ -30,8 +30,8 @@ func TestAuthorityBuilderAddKey(t *testing.T) { storage := ChonkMem() a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + Keys: []Key{key}, + DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, }, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) @@ -64,8 +64,8 @@ func TestAuthorityBuilderMaxKey(t *testing.T) { storage := ChonkMem() a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + Keys: []Key{key}, + DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, }, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) @@ -111,8 +111,8 @@ func TestAuthorityBuilderRemoveKey(t *testing.T) { storage := ChonkMem() a, _, err := Create(storage, State{ - Keys: []Key{key, key2}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + Keys: []Key{key, key2}, + DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, }, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) @@ -157,8 +157,8 @@ func TestAuthorityBuilderSetKeyVote(t *testing.T) { storage := ChonkMem() a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + Keys: []Key{key}, + DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, }, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) @@ -193,8 +193,8 @@ func TestAuthorityBuilderSetKeyMeta(t *testing.T) { storage := ChonkMem() a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + Keys: []Key{key}, + DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, }, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) @@ -229,8 +229,8 @@ func TestAuthorityBuilderMultiple(t *testing.T) { storage := ChonkMem() a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + Keys: []Key{key}, + DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, }, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) @@ -277,8 +277,8 @@ func TestAuthorityBuilderCheckpointsAfterXUpdates(t *testing.T) { storage := ChonkMem() a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + Keys: []Key{key}, + DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, }, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) diff --git a/tka/chaintest_test.go b/tka/chaintest_test.go index c370bf60a2e4c..0ec612210b722 100644 --- a/tka/chaintest_test.go +++ b/tka/chaintest_test.go @@ -7,6 +7,7 @@ import ( "bytes" "crypto/ed25519" "fmt" + "maps" "strconv" "strings" "testing" @@ -198,14 +199,12 @@ func (c *testChain) recordParent(t *testing.T, child, parent string) { // This method populates c.AUMs and c.AUMHashes. func (c *testChain) buildChain() { pending := make(map[string]*testchainNode, len(c.Nodes)) - for k, v := range c.Nodes { - pending[k] = v - } + maps.Copy(pending, c.Nodes) // AUMs with a parent need to know their hash, so we - // only compute AUMs who's parents have been computed + // only compute AUMs whose parents have been computed // each iteration. Since at least the genesis AUM - // had no parent, theres always a path to completion + // had no parent, there's always a path to completion // in O(n+1) where n is the number of AUMs. c.AUMs = make(map[string]AUM, len(c.Nodes)) c.AUMHashes = make(map[string]AUMHash, len(c.Nodes)) @@ -321,6 +320,22 @@ func optTemplate(name string, template AUM) testchainOpt { } } +func genesisTemplate(key Key) testchainOpt { + return optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ + Keys: []Key{key}, + DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }}) +} + +func checkpointTemplate() testchainOpt { + fakeState := &State{ + Keys: []Key{{Kind: Key25519, Votes: 1}}, + DisablementValues: [][]byte{bytes.Repeat([]byte{1}, 32)}, + } + + return optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState}) +} + func optKey(name string, key Key, priv ed25519.PrivateKey) testchainOpt { return testchainOpt{ Name: name, diff --git a/tka/deeplink_test.go b/tka/deeplink_test.go index 6d85b158589ac..260ec9026ede9 100644 --- a/tka/deeplink_test.go +++ b/tka/deeplink_test.go @@ -14,11 +14,7 @@ func TestGenerateDeeplink(t *testing.T) { G1 -> L1 G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + `, genesisTemplate(key), ) a, _ := Open(c.Chonk()) diff --git a/tka/disabled_stub.go b/tka/disabled_stub.go index d14473e5ec1ac..f3cabd491dd12 100644 --- a/tka/disabled_stub.go +++ b/tka/disabled_stub.go @@ -8,6 +8,7 @@ package tka import ( "crypto/ed25519" "errors" + "time" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -158,3 +159,8 @@ func SignByCredential(privKey []byte, wrapped *NodeKeySignature, nodeKey key.Nod } func (s NodeKeySignature) String() string { return "" } + +type CompactionOptions struct { + MinChain int + MinAge time.Duration +} diff --git a/tka/key.go b/tka/key.go index bc946156eb9be..08897d4095889 100644 --- a/tka/key.go +++ b/tka/key.go @@ -7,6 +7,7 @@ import ( "crypto/ed25519" "errors" "fmt" + "maps" "tailscale.com/types/tkatype" ) @@ -64,9 +65,7 @@ func (k Key) Clone() Key { if k.Meta != nil { out.Meta = make(map[string]string, len(k.Meta)) - for k, v := range k.Meta { - out.Meta[k] = v - } + maps.Copy(out.Meta, k.Meta) } return out @@ -105,8 +104,6 @@ func (k Key) Ed25519() (ed25519.PublicKey, error) { } } -const maxMetaBytes = 512 - func (k Key) StaticValidate() error { if k.Votes > 4096 { return fmt.Errorf("excessive key weight: %d > 4096", k.Votes) diff --git a/tka/key_test.go b/tka/key_test.go index 799accc857e1c..cc6a1f580c013 100644 --- a/tka/key_test.go +++ b/tka/key_test.go @@ -73,8 +73,8 @@ func TestNLPrivate(t *testing.T) { // authority. k := Key{Kind: Key25519, Public: pub.Verifier(), Votes: 1} _, aum, err := Create(ChonkMem(), State{ - Keys: []Key{k}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, + Keys: []Key{k}, + DisablementValues: [][]byte{bytes.Repeat([]byte{1}, 32)}, }, p) if err != nil { t.Fatalf("Create() failed: %v", err) diff --git a/tka/limits.go b/tka/limits.go new file mode 100644 index 0000000000000..11f53654f2e8c --- /dev/null +++ b/tka/limits.go @@ -0,0 +1,35 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tka + +import ( + "time" +) + +const ( + // Upper bound on checkpoint elements, chosen arbitrarily. Intended + // to cap the size of large AUMs. + maxDisablementValues = 32 + maxKeys = 512 + + // Max amount of metadata that can be associated with a key, chosen arbitrarily. + // Intended to avoid people abusing TKA as a key-value score. + maxMetaBytes = 512 + + // Max iterations searching for any intersection during the sync process. + maxSyncIter = 2000 + + // Max iterations searching for a head intersection during the sync process. + maxSyncHeadIntersectionIter = 400 + + // Limit on scanning AUM trees, chosen arbitrarily. + maxScanIterations = 2000 +) + +var ( + CompactionDefaults = CompactionOptions{ + MinChain: 24, // Keep at minimum 24 AUMs since head. + MinAge: 14 * 24 * time.Hour, // Keep 2 weeks of AUMs. + } +) diff --git a/tka/scenario_test.go b/tka/scenario_test.go index cf4ee2d5b2582..61d9e25290ed8 100644 --- a/tka/scenario_test.go +++ b/tka/scenario_test.go @@ -5,6 +5,7 @@ package tka import ( "crypto/ed25519" + "maps" "sort" "testing" ) @@ -36,9 +37,7 @@ func (s *scenarioTest) mkNode(name string) *scenarioNode { } aums := make(map[string]AUM, len(s.initial.AUMs)) - for k, v := range s.initial.AUMs { - aums[k] = v - } + maps.Copy(aums, s.initial.AUMs) n := &scenarioNode{ A: authority, @@ -148,10 +147,7 @@ func testScenario(t *testing.T, sharedChain string, sharedOptions ...testchainOp pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 1} sharedOptions = append(sharedOptions, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(key), optKey("key", key, priv), optSignAllUsing("key")) diff --git a/tka/sig_test.go b/tka/sig_test.go index efec62b7d791f..4581d4cc3ce9b 100644 --- a/tka/sig_test.go +++ b/tka/sig_test.go @@ -173,11 +173,8 @@ func TestSigNested_DeepNesting(t *testing.T) { } // Test this works with our public API - a, _ := Open(newTestchain(t, "G1\nG1.template = genesis", - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }})).Chonk()) + c := newTestchain(t, "G1\nG1.template = genesis", genesisTemplate(k)) + a, _ := Open(c.Chonk()) if err := a.NodeKeyAuthorized(lastNodeKey.Public(), outer.Serialize()); err != nil { t.Errorf("NodeKeyAuthorized(lastNodeKey) failed: %v", err) } @@ -238,11 +235,8 @@ func TestSigCredential(t *testing.T) { } // Test someone can't misuse our public API for verifying node-keys - a, _ := Open(newTestchain(t, "G1\nG1.template = genesis", - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }})).Chonk()) + c := newTestchain(t, "G1\nG1.template = genesis", genesisTemplate(k)) + a, _ := Open(c.Chonk()) if err := a.NodeKeyAuthorized(node.Public(), nestedSig.Serialize()); err == nil { t.Error("NodeKeyAuthorized(SigCredential, node) did not fail") } diff --git a/tka/state.go b/tka/state.go index 06fdc65048b59..dc10f0553e9cd 100644 --- a/tka/state.go +++ b/tka/state.go @@ -7,6 +7,7 @@ package tka import ( "bytes" + "crypto/subtle" "errors" "fmt" @@ -28,9 +29,13 @@ type State struct { // is the same as the LastAUMHash. LastAUMHash *AUMHash `cbor:"1,keyasint"` - // DisablementSecrets are KDF-derived values which can be used - // to turn off the TKA in the event of a consensus-breaking bug. - DisablementSecrets [][]byte `cbor:"2,keyasint"` + // DisablementValues are KDF-derived values used to verify that a caller + // possesses a valid DisablementSecret. These values are used during the + // Tailnet Lock deactivation process. + // + // These are safe to share publicly or store in the clear. They cannot be + // used to derive the original DisablementSecret. + DisablementValues [][]byte `cbor:"2,keyasint"` // Keys are the public keys of either: // @@ -78,11 +83,11 @@ func (s State) Clone() State { out.LastAUMHash = &dupe } - if s.DisablementSecrets != nil { - out.DisablementSecrets = make([][]byte, len(s.DisablementSecrets)) - for i := range s.DisablementSecrets { - out.DisablementSecrets[i] = make([]byte, len(s.DisablementSecrets[i])) - copy(out.DisablementSecrets[i], s.DisablementSecrets[i]) + if s.DisablementValues != nil { + out.DisablementValues = make([][]byte, len(s.DisablementValues)) + for i := range s.DisablementValues { + out.DisablementValues[i] = make([]byte, len(s.DisablementValues[i])) + copy(out.DisablementValues[i], s.DisablementValues[i]) } } @@ -113,7 +118,7 @@ var disablementSalt = []byte("tailscale network-lock disablement salt") // key authority, but cannot be reversed to find the input secret. // // When the output of this function is stored in tka state (i.e. in -// tka.State.DisablementSecrets) a call to Authority.ValidDisablement() +// tka.State.DisablementValues) a call to Authority.ValidDisablement() // with the input of this function as the argument will return true. func DisablementKDF(secret []byte) []byte { // time = 4 (3 recommended, booped to 4 to compensate for less memory) @@ -126,8 +131,8 @@ func DisablementKDF(secret []byte) []byte { // checkDisablement returns true for a valid disablement secret. func (s State) checkDisablement(secret []byte) bool { derived := DisablementKDF(secret) - for _, candidate := range s.DisablementSecrets { - if bytes.Equal(derived, candidate) { + for _, candidate := range s.DisablementValues { + if subtle.ConstantTimeCompare(derived, candidate) == 1 { return true } } @@ -247,30 +252,23 @@ func (s State) applyVerifiedAUM(update AUM) (State, error) { } } -// Upper bound on checkpoint elements, chosen arbitrarily. Intended to -// cap out insanely large AUMs. -const ( - maxDisablementSecrets = 32 - maxKeys = 512 -) - // staticValidateCheckpoint validates that the state is well-formed for // inclusion in a checkpoint AUM. func (s *State) staticValidateCheckpoint() error { if s.LastAUMHash != nil { return errors.New("cannot specify a parent AUM") } - if len(s.DisablementSecrets) == 0 { + if len(s.DisablementValues) == 0 { return errors.New("at least one disablement secret required") } - if numDS := len(s.DisablementSecrets); numDS > maxDisablementSecrets { - return fmt.Errorf("too many disablement secrets (%d, max %d)", numDS, maxDisablementSecrets) + if numDS := len(s.DisablementValues); numDS > maxDisablementValues { + return fmt.Errorf("too many disablement values (%d, max %d)", numDS, maxDisablementValues) } - for i, ds := range s.DisablementSecrets { + for i, ds := range s.DisablementValues { if len(ds) != disablementLength { return fmt.Errorf("disablement[%d]: invalid length (got %d, want %d)", i, len(ds), disablementLength) } - for j, ds2 := range s.DisablementSecrets { + for j, ds2 := range s.DisablementValues { if i == j { continue } diff --git a/tka/state_test.go b/tka/state_test.go index 337e3c3ceff85..e5208e4e6e71d 100644 --- a/tka/state_test.go +++ b/tka/state_test.go @@ -36,26 +36,26 @@ func TestCloneState(t *testing.T) { State State }{ { - "Empty", - State{}, + Name: "Empty", + State: State{}, }, { - "Key", - State{ + Name: "Key", + State: State{ Keys: []Key{{Kind: Key25519, Votes: 2, Public: []byte{5, 6, 7, 8}, Meta: map[string]string{"a": "b"}}}, }, }, { - "StateID", - State{ + Name: "StateID", + State: State{ StateID1: 42, StateID2: 22, }, }, { - "DisablementSecrets", - State{ - DisablementSecrets: [][]byte{ + Name: "DisablementValues", + State: State{ + DisablementValues: [][]byte{ {1, 2, 3, 4}, {5, 6, 7, 8}, }, @@ -155,7 +155,7 @@ func TestApplyUpdatesChain(t *testing.T) { Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, }, PrevAUMHash: fromHex("f09bda3bb7cf6756ea9adc25770aede4b3ca8142949d6ef5ca0add29af912fd4")}, }, - State{DisablementSecrets: [][]byte{{1, 2, 3, 4}}}, + State{DisablementValues: [][]byte{{1, 2, 3, 4}}}, State{ Keys: []Key{{Kind: Key25519, Public: []byte{1, 2, 3, 4}}}, LastAUMHash: hashFromHex("57343671da5eea3cfb502954e976e8028bffd3540b50a043b2a65a8d8d8217d0"), diff --git a/tka/sync.go b/tka/sync.go index 27e1c0e633329..5cae9b45f9694 100644 --- a/tka/sync.go +++ b/tka/sync.go @@ -11,13 +11,6 @@ import ( "os" ) -const ( - // Max iterations searching for any intersection. - maxSyncIter = 2000 - // Max iterations searching for a head intersection. - maxSyncHeadIntersectionIter = 400 -) - // ErrNoIntersection is returned when a shared AUM could // not be determined when evaluating a remote sync offer. var ErrNoIntersection = errors.New("no intersection") @@ -107,7 +100,7 @@ func (a *Authority) SyncOffer(storage Chonk) (SyncOffer, error) { skipAmount uint64 = ancestorsSkipStart curs AUMHash = a.Head() ) - for i := uint64(0); i < maxSyncHeadIntersectionIter; i++ { + for i := range uint64(maxSyncHeadIntersectionIter) { if i > 0 && (i%skipAmount) == 0 { out.Ancestors = append(out.Ancestors, curs) skipAmount = skipAmount << ancestorsSkipShift diff --git a/tka/sync_test.go b/tka/sync_test.go index 158f73c46cb01..48f197e8c3a19 100644 --- a/tka/sync_test.go +++ b/tka/sync_test.go @@ -11,22 +11,30 @@ import ( "github.com/google/go-cmp/cmp" ) -func TestSyncOffer(t *testing.T) { - c := newTestchain(t, ` - A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 - A10 -> A11 -> A12 -> A13 -> A14 -> A15 -> A16 -> A17 -> A18 - A18 -> A19 -> A20 -> A21 -> A22 -> A23 -> A24 -> A25 - `) - storage := c.Chonk() +// getSyncOffer returns a SyncOffer for the given Chonk. +func getSyncOffer(t *testing.T, storage Chonk) SyncOffer { + t.Helper() + a, err := Open(storage) if err != nil { t.Fatal(err) } - got, err := a.SyncOffer(storage) + offer, err := a.SyncOffer(storage) if err != nil { t.Fatal(err) } + return offer +} + +func TestSyncOffer(t *testing.T) { + c := newTestchain(t, ` + A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 + A10 -> A11 -> A12 -> A13 -> A14 -> A15 -> A16 -> A17 -> A18 + A18 -> A19 -> A20 -> A21 -> A22 -> A23 -> A24 -> A25 + `) + got := getSyncOffer(t, c.Chonk()) + // A SyncOffer includes a selection of AUMs going backwards in the tree, // progressively skipping more and more each iteration. want := SyncOffer{ @@ -52,24 +60,10 @@ func TestComputeSyncIntersection_FastForward(t *testing.T) { a1H, a2H := c.AUMHashes["A1"], c.AUMHashes["A2"] chonk1 := c.ChonkWith("A1", "A2") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } + offer1 := getSyncOffer(t, chonk1) chonk2 := c.Chonk() // All AUMs - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } + offer2 := getSyncOffer(t, chonk2) // Node 1 only knows about the first two nodes, so the head of n2 is // alien to it. @@ -123,40 +117,28 @@ func TestComputeSyncIntersection_ForkSmallDiff(t *testing.T) { } chonk1 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "F1") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(SyncOffer{ + offer1 := getSyncOffer(t, chonk1) + want1 := SyncOffer{ Head: c.AUMHashes["F1"], Ancestors: []AUMHash{ c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)], c.AUMHashes["A1"], }, - }, offer1); diff != "" { + } + if diff := cmp.Diff(want1, offer1); diff != "" { t.Errorf("offer1 diff (-want, +got):\n%s", diff) } chonk2 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10") - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(SyncOffer{ + offer2 := getSyncOffer(t, chonk2) + want2 := SyncOffer{ Head: c.AUMHashes["A10"], Ancestors: []AUMHash{ c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)], c.AUMHashes["A1"], }, - }, offer2); diff != "" { + } + if diff := cmp.Diff(want2, offer2); diff != "" { t.Errorf("offer2 diff (-want, +got):\n%s", diff) } @@ -339,10 +321,7 @@ func TestSyncSimpleE2E(t *testing.T) { G1 -> L1 -> L2 -> L3 G1.template = genesis `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(key), optKey("key", key, priv), optSignAllUsing("key")) diff --git a/tka/tailchonk.go b/tka/tailchonk.go index 256faaea2b8b9..3b083f327f3e7 100644 --- a/tka/tailchonk.go +++ b/tka/tailchonk.go @@ -715,7 +715,7 @@ func markActiveChain(storage Chonk, verdict map[AUMHash]retainState, minChain in parent, hasParent := next.Parent() if !hasParent { - // Genesis AUM (beginning of time). The chain isnt long enough to need truncating. + // Genesis AUM (beginning of time). The chain isn't long enough to need truncating. return h, nil } diff --git a/tka/tailchonk_test.go b/tka/tailchonk_test.go index d40e4b09da769..23bf45e20c5e4 100644 --- a/tka/tailchonk_test.go +++ b/tka/tailchonk_test.go @@ -185,7 +185,7 @@ func TestMarkActiveChain(t *testing.T) { expectLastActiveIdx: 0, }, { - name: "simple truncate", + name: "simple-truncate", minChain: 2, chain: []aumTemplate{ {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, @@ -196,7 +196,7 @@ func TestMarkActiveChain(t *testing.T) { expectLastActiveIdx: 1, }, { - name: "long truncate", + name: "long-truncate", minChain: 5, chain: []aumTemplate{ {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, @@ -211,7 +211,7 @@ func TestMarkActiveChain(t *testing.T) { expectLastActiveIdx: 2, }, { - name: "truncate finding checkpoint", + name: "truncate-finding-checkpoint", minChain: 2, chain: []aumTemplate{ {AUM: AUM{MessageKind: AUMCheckpoint, State: &State{}}}, @@ -309,17 +309,12 @@ func TestMarkDescendantAUMs(t *testing.T) { } for _, h := range []AUMHash{hs["genesis"], hs["B"], hs["D"]} { if (verdict[h] & retainStateLeaf) != 0 { - t.Errorf("%v was marked as a descendant and shouldnt be", h) + t.Errorf("%v was marked as a descendant and shouldn't be", h) } } } func TestMarkAncestorIntersectionAUMs(t *testing.T) { - fakeState := &State{ - Keys: []Key{{Kind: Key25519, Votes: 1}}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, - } - tcs := []struct { name string chain *testChain @@ -333,7 +328,7 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { name: "genesis", chain: newTestchain(t, ` A - A.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + A.template = checkpoint`, checkpointTemplate()), initialAncestor: "A", wantAncestor: "A", verdicts: map[string]retainState{ @@ -342,11 +337,11 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { wantRetained: []string{"A"}, }, { - name: "no adjustment", + name: "no-adjustment", chain: newTestchain(t, ` DEAD -> A -> B -> C A.template = checkpoint - B.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + B.template = checkpoint`, checkpointTemplate()), initialAncestor: "A", wantAncestor: "A", verdicts: map[string]retainState{ @@ -366,7 +361,7 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { A.template = checkpoint C.template = checkpoint D.template = checkpoint - FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + FORK.hashSeed = 2`, checkpointTemplate()), initialAncestor: "D", wantAncestor: "C", verdicts: map[string]retainState{ @@ -380,14 +375,14 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { wantDeleted: []string{"A", "B"}, }, { - name: "fork finding earlier checkpoint", + name: "fork-finding-earlier-checkpoint", chain: newTestchain(t, ` A -> B -> C -> D -> E -> F | -> FORK A.template = checkpoint B.template = checkpoint E.template = checkpoint - FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + FORK.hashSeed = 2`, checkpointTemplate()), initialAncestor: "E", wantAncestor: "B", verdicts: map[string]retainState{ @@ -403,7 +398,7 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { wantDeleted: []string{"A"}, }, { - name: "fork multi", + name: "fork-multi", chain: newTestchain(t, ` A -> B -> C -> D -> E | -> DEADFORK @@ -413,7 +408,7 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { D.template = checkpoint E.template = checkpoint FORK.hashSeed = 2 - DEADFORK.hashSeed = 3`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + DEADFORK.hashSeed = 3`, checkpointTemplate()), initialAncestor: "D", wantAncestor: "C", verdicts: map[string]retainState{ @@ -429,7 +424,7 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { wantDeleted: []string{"A", "B", "DEADFORK"}, }, { - name: "fork multi 2", + name: "fork-multi-2", chain: newTestchain(t, ` A -> B -> C -> D -> E -> F -> G @@ -443,7 +438,7 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { F.template = checkpoint F1.hashSeed = 2 F2.hashSeed = 3 - F3.hashSeed = 4`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + F3.hashSeed = 4`, checkpointTemplate()), initialAncestor: "F", wantAncestor: "B", verdicts: map[string]retainState{ @@ -541,11 +536,6 @@ func cloneMem(src, dst *Mem) { } func TestCompact(t *testing.T) { - fakeState := &State{ - Keys: []Key{{Kind: Key25519, Votes: 1}}, - DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, - } - // A & B are deleted because the new lastActiveAncestor advances beyond them. // OLD is deleted because it does not match retention criteria, and // though it is a descendant of the new lastActiveAncestor (C), it is not a @@ -578,7 +568,7 @@ func TestCompact(t *testing.T) { F1.hashSeed = 1 OLD.hashSeed = 2 G2.hashSeed = 3 - `, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})) + `, checkpointTemplate()) storage := &compactingChonkFake{ aumAge: map[AUMHash]time.Time{(c.AUMHashes["F1"]): time.Now()}, @@ -610,8 +600,8 @@ func TestCompactLongButYoung(t *testing.T) { storage := ChonkMem() auth, _, err := Create(storage, State{ - Keys: []Key{ourKey, someOtherKey}, - DisablementSecrets: [][]byte{DisablementKDF(bytes.Repeat([]byte{0xa5}, 32))}, + Keys: []Key{ourKey, someOtherKey}, + DisablementValues: [][]byte{DisablementKDF(bytes.Repeat([]byte{0xa5}, 32))}, }, ourPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) diff --git a/tka/tka.go b/tka/tka.go index e3862c29d3264..9b22edc2eb505 100644 --- a/tka/tka.go +++ b/tka/tka.go @@ -31,9 +31,6 @@ var cborDecOpts = cbor.DecOptions{ MaxMapPairs: 1024, } -// Arbitrarily chosen limit on scanning AUM trees. -const maxScanIterations = 2000 - // Authority is a Tailnet Key Authority. This type is the main coupling // point to the rest of the tailscale client. // diff --git a/tka/tka_test.go b/tka/tka_test.go index f2ce73d357343..4bd0ac0839584 100644 --- a/tka/tka_test.go +++ b/tka/tka_test.go @@ -304,10 +304,7 @@ func TestAuthorityValidDisablement(t *testing.T) { G1.template = genesis `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(key), ) a, _ := Open(c.Chonk()) @@ -321,8 +318,8 @@ func TestCreateBootstrapAuthority(t *testing.T) { key := Key{Kind: Key25519, Public: pub, Votes: 2} a1, genesisAUM, err := Create(ChonkMem(), State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + Keys: []Key{key}, + DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, }, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) @@ -353,8 +350,8 @@ func TestBootstrapChonkMustBeEmpty(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} state := State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + Keys: []Key{key}, + DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, } // Bootstrap our chonk for the first time, which should succeed. @@ -419,10 +416,7 @@ func TestAuthorityInformNonLinear(t *testing.T) { L2.hashSeed = 2 L4.hashSeed = 2 `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(key), optKey("key", key, priv), optSignAllUsing("key")) @@ -464,10 +458,7 @@ func TestAuthorityInformLinear(t *testing.T) { G1.template = genesis `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(key), optKey("key", key, priv), optSignAllUsing("key")) @@ -517,7 +508,7 @@ func TestInteropWithNLKey(t *testing.T) { Public: pub2.KeyID(), }, }, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, }, priv1) if err != nil { t.Errorf("tka.Create: %v", err) @@ -545,13 +536,10 @@ func TestAuthorityCompact(t *testing.T) { G.template = genesis C.template = checkpoint2 `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(key), optTemplate("checkpoint2", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + Keys: []Key{key}, + DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, }}), optKey("key", key, priv), optSignAllUsing("key")) @@ -602,10 +590,7 @@ func TestFindParentForRewrite(t *testing.T) { C.template = add3 D.template = remove2 `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k1}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(k1), optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3}), optTemplate("remove2", AUM{MessageKind: AUMRemoveKey, KeyID: k2ID})) @@ -671,10 +656,7 @@ func TestMakeRetroactiveRevocation(t *testing.T) { C.template = add2 D.template = add3 `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k1}, - DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(k1), optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3})) diff --git a/tool/go-win.ps1 b/tool/go-win.ps1 deleted file mode 100644 index 49313ffbabee9..0000000000000 --- a/tool/go-win.ps1 +++ /dev/null @@ -1,64 +0,0 @@ -<# - go.ps1 – Tailscale Go toolchain fetching wrapper for Windows/PowerShell - â€ĸ Reads go.toolchain.rev one dir above this script - â€ĸ If the requested commit hash isn't cached, downloads and unpacks - https://github.com/tailscale/go/releases/download/build-${REV}/${OS}-${ARCH}.tar.gz - â€ĸ Finally execs the toolchain's "go" binary, forwarding all args & exit-code -#> - -param( - [Parameter(ValueFromRemainingArguments = $true)] - [string[]] $Args -) - -Set-StrictMode -Version Latest -$ErrorActionPreference = 'Stop' - -if ($env:CI -eq 'true' -and $env:NODEBUG -ne 'true') { - $VerbosePreference = 'Continue' -} - -$repoRoot = Resolve-Path (Join-Path $PSScriptRoot '..') -$REV = (Get-Content (Join-Path $repoRoot 'go.toolchain.rev') -Raw).Trim() - -if ([IO.Path]::IsPathRooted($REV)) { - $toolchain = $REV -} else { - if (-not [string]::IsNullOrWhiteSpace($env:TSGO_CACHE_ROOT)) { - $cacheRoot = $env:TSGO_CACHE_ROOT - } else { - $cacheRoot = Join-Path $env:USERPROFILE '.cache\tsgo' - } - - $toolchain = Join-Path $cacheRoot $REV - $marker = "$toolchain.extracted" - - if (-not (Test-Path $marker)) { - Write-Host "# Downloading Go toolchain $REV" -ForegroundColor Cyan - if (Test-Path $toolchain) { Remove-Item -Recurse -Force $toolchain } - - # Removing the marker file again (even though it shouldn't still exist) - # because the equivalent Bash script also does so (to guard against - # concurrent cache fills?). - # TODO(bradfitz): remove this and add some proper locking instead? - if (Test-Path $marker ) { Remove-Item -Force $marker } - - New-Item -ItemType Directory -Path $cacheRoot -Force | Out-Null - - $url = "https://github.com/tailscale/go/releases/download/build-$REV/windows-amd64.tar.gz" - $tgz = "$toolchain.tar.gz" - Invoke-WebRequest -Uri $url -OutFile $tgz -UseBasicParsing -ErrorAction Stop - - New-Item -ItemType Directory -Path $toolchain -Force | Out-Null - tar --strip-components=1 -xzf $tgz -C $toolchain - Remove-Item $tgz - Set-Content -Path $marker -Value $REV - } -} - -$goExe = Join-Path $toolchain 'bin\go.exe' -if (-not (Test-Path $goExe)) { throw "go executable not found at $goExe" } - -& $goExe @Args -exit $LASTEXITCODE - diff --git a/tool/go.cmd b/tool/go.cmd deleted file mode 100644 index b7b5d0483b972..0000000000000 --- a/tool/go.cmd +++ /dev/null @@ -1,36 +0,0 @@ -@echo off -rem Checking for PowerShell Core using PowerShell for Windows... -powershell -NoProfile -NonInteractive -Command "& {Get-Command -Name pwsh -ErrorAction Stop}" > NUL -if ERRORLEVEL 1 ( - rem Ask the user whether they should install the dependencies. Note that this - rem code path never runs in CI because pwsh is always explicitly installed. - - rem Time out after 5 minutes, defaulting to 'N' - choice /c yn /t 300 /d n /m "PowerShell Core is required. Install now" - if ERRORLEVEL 2 ( - echo Aborting due to unmet dependencies. - exit /b 1 - ) - - rem Check for a .NET Core runtime using PowerShell for Windows... - powershell -NoProfile -NonInteractive -Command "& {if (-not (dotnet --list-runtimes | Select-String 'Microsoft\.NETCore\.App' -Quiet)) {exit 1}}" > NUL - rem Install .NET Core if missing to provide PowerShell Core's runtime library. - if ERRORLEVEL 1 ( - rem Time out after 5 minutes, defaulting to 'N' - choice /c yn /t 300 /d n /m "PowerShell Core requires .NET Core for its runtime library. Install now" - if ERRORLEVEL 2 ( - echo Aborting due to unmet dependencies. - exit /b 1 - ) - - winget install --accept-package-agreements --id Microsoft.DotNet.Runtime.8 -e --source winget - ) - - rem Now install PowerShell Core. - winget install --accept-package-agreements --id Microsoft.PowerShell -e --source winget - if ERRORLEVEL 0 echo Please re-run this script within a new console session to pick up PATH changes. - rem Either way we didn't build, so return 1. - exit /b 1 -) - -pwsh -NoProfile -ExecutionPolicy Bypass "%~dp0..\tool\gocross\gocross-wrapper.ps1" %* diff --git a/tool/go.exe b/tool/go.exe new file mode 100755 index 0000000000000..f295d6ac8b436 Binary files /dev/null and b/tool/go.exe differ diff --git a/tool/go.exe.README.txt b/tool/go.exe.README.txt new file mode 100644 index 0000000000000..3f4988599b28f --- /dev/null +++ b/tool/go.exe.README.txt @@ -0,0 +1,20 @@ +What is go.exe, and why's a 32-bit x86 Windows binary checked into the repo? + +See https://github.com/tailscale/tailscale/pull/19256 + +In summary, our previous attempts to provide a version of ./tool/go (a +shell script) on Windows with PowerShell and cmd.exe both were +lacking. + +So now we we're regrettably checking in a binary to the tree. Its +source code is in ./tool/goexe. It's written in Rust without std so +it's very small (smaller than plenty of of our source code files!) and +it's 32-bit x86 so it runs on 32-bit x86, 64-bit x86, and arm64 Windows +where it's emulated. + +This binary is not required, but it's used by our build system and +people working on Tailscale who are used to being able to run +"./tool/go" and have it do the right hermetic thing, using the correct +Go toolchain. + + diff --git a/tool/gocross/exec_other.go b/tool/gocross/exec_other.go index 20e52aa8f9496..b9004b8d52c70 100644 --- a/tool/gocross/exec_other.go +++ b/tool/gocross/exec_other.go @@ -21,8 +21,7 @@ func doExec(cmd string, args []string, env []string) error { // Propagate ExitErrors within this func to give us similar semantics to // the Unix variant. - var ee *exec.ExitError - if errors.As(err, &ee) { + if ee, ok := errors.AsType[*exec.ExitError](err); ok { os.Exit(ee.ExitCode()) } diff --git a/tool/gocross/gocross_wrapper_test.go b/tool/gocross/gocross_wrapper_test.go index 7fc81207f6379..035b57c162b16 100644 --- a/tool/gocross/gocross_wrapper_test.go +++ b/tool/gocross/gocross_wrapper_test.go @@ -6,13 +6,26 @@ package main import ( + "bytes" + "go/version" "os" "os/exec" + "runtime" "strings" "testing" + + "tailscale.com/util/must" ) func TestGocrossWrapper(t *testing.T) { + if version.Compare(runtime.Version(), "go1.27") < 0 { + gitDir := must.Get(exec.Command("git", "rev-parse", "--git-dir").Output()) + gitCommonDir := must.Get(exec.Command("git", "rev-parse", "--git-common-dir").Output()) + if !bytes.Equal(gitDir, gitCommonDir) { + t.Skip("skipping within git worktree, see https://go.dev/issue/58218") + } + } + for i := range 2 { // once to build gocross; second to test it's cached cmd := exec.Command("./gocross-wrapper.sh", "version") cmd.Env = append(os.Environ(), "CI=true", "NOBASHDEBUG=false", "TS_USE_GOCROSS=1") // for "set -x" verbosity diff --git a/tool/gocross/gocross_wrapper_windows_test.go b/tool/gocross/gocross_wrapper_windows_test.go index ed565e15ad677..83f3e7b791f3e 100644 --- a/tool/gocross/gocross_wrapper_windows_test.go +++ b/tool/gocross/gocross_wrapper_windows_test.go @@ -4,13 +4,26 @@ package main import ( + "bytes" + "go/version" "os" "os/exec" + "runtime" "strings" "testing" + + "tailscale.com/util/must" ) func TestGocrossWrapper(t *testing.T) { + if version.Compare(runtime.Version(), "go1.27") < 0 { + gitDir := must.Get(exec.Command("git", "rev-parse", "--git-dir").Output()) + gitCommonDir := must.Get(exec.Command("git", "rev-parse", "--git-common-dir").Output()) + if !bytes.Equal(gitDir, gitCommonDir) { + t.Skip("skipping within git worktree, see https://go.dev/issue/58218") + } + } + for i := range 2 { // once to build gocross; second to test it's cached cmd := exec.Command("pwsh", "-NoProfile", "-ExecutionPolicy", "Bypass", ".\\gocross-wrapper.ps1", "version") cmd.Env = append(os.Environ(), "CI=true", "NOPWSHDEBUG=false", "TS_USE_GOCROSS=1") // for Set-PSDebug verbosity diff --git a/tool/goexe/.cargo/config.toml b/tool/goexe/.cargo/config.toml new file mode 100644 index 0000000000000..68874b76557b6 --- /dev/null +++ b/tool/goexe/.cargo/config.toml @@ -0,0 +1,5 @@ +[build] +target = "i686-pc-windows-gnu" + +[target.i686-pc-windows-gnu] +rustflags = ["-C", "link-args=-nostartfiles -lkernel32"] diff --git a/tool/goexe/.gitignore b/tool/goexe/.gitignore new file mode 100644 index 0000000000000..97342e3ae1840 --- /dev/null +++ b/tool/goexe/.gitignore @@ -0,0 +1,2 @@ +/target/ +/go.exe diff --git a/tool/goexe/Cargo.lock b/tool/goexe/Cargo.lock new file mode 100644 index 0000000000000..670b23d414852 --- /dev/null +++ b/tool/goexe/Cargo.lock @@ -0,0 +1,25 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "go" +version = "0.1.0" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] diff --git a/tool/goexe/Cargo.toml b/tool/goexe/Cargo.toml new file mode 100644 index 0000000000000..f20ea6e9d1c9e --- /dev/null +++ b/tool/goexe/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "go" +version = "0.1.0" +edition = "2024" + +[dependencies] +windows-sys = { version = "0.61", features = [ + "Win32", + "Win32_System_LibraryLoader", + "Win32_System_Environment", + "Win32_Storage_FileSystem", + "Win32_System_Threading", + "Win32_Security", + "Win32_System_SystemInformation", + "Win32_System_IO", + "Win32_System_Console", +] } + +[profile.release] +opt-level = "z" +lto = true +codegen-units = 1 +panic = "abort" +strip = true diff --git a/tool/goexe/Makefile b/tool/goexe/Makefile new file mode 100644 index 0000000000000..a1f6f1f3bb3e3 --- /dev/null +++ b/tool/goexe/Makefile @@ -0,0 +1,28 @@ +# Builds tool/go.exe, a thin wrapper that execs the Tailscale Go +# toolchain without going through cmd.exe (which mangles ^ and other +# special characters in arguments). +# See https://github.com/tailscale/tailscale/issues/19255 +# +# Built as no_std Rust with raw Win32 API calls for minimal size (~17KB). +# The resulting go.exe is checked into the repo at tool/go.exe. +# +# Built as 32-bit x86 so one binary runs on x86, x64 (via WoW64), +# and ARM64 (via Windows x86 emulation). +# +# Requirements: +# rustup target add i686-pc-windows-gnu +# apt install gcc-mingw-w64-i686 (or equivalent) + +RUST_TARGET = i686-pc-windows-gnu + +.PHONY: all clean + +all: go.exe + +go.exe: src/main.rs Cargo.toml + cargo build --release --target $(RUST_TARGET) + cp target/$(RUST_TARGET)/release/go.exe $@ + +clean: + rm -f go.exe + rm -rf target diff --git a/tool/goexe/src/main.rs b/tool/goexe/src/main.rs new file mode 100644 index 0000000000000..27c7e2056a86b --- /dev/null +++ b/tool/goexe/src/main.rs @@ -0,0 +1,482 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//! A thin wrapper that finds and execs the Tailscale Go toolchain without +//! going through cmd.exe, avoiding its argument mangling (cmd.exe treats ^ +//! as an escape character, breaking -run "^$" and similar, and = signs +//! also cause issues in PowerShell→cmd.exe argument passing). +//! See https://github.com/tailscale/tailscale/issues/19255. +//! +//! This replaces tool/go.cmd. When PowerShell resolves `./tool/go`, it +//! prefers go.exe over go.cmd, so this binary is used automatically. +//! +//! Built as no_std with raw Win32 API calls for minimal binary size (~17KB). +//! Built as 32-bit x86 so one binary runs on x86, x64 (via WoW64), and +//! ARM64 (via Windows x86 emulation). +//! +//! The raw command line from GetCommandLineW is passed through directly to +//! CreateProcessW (after swapping out argv[0]), so arguments are never +//! parsed or re-escaped, preserving them exactly as the caller specified. + +#![no_std] +#![no_main] +#![windows_subsystem = "console"] +// Every function in this program calls raw Win32 FFI; requiring unsafe +// blocks inside each unsafe fn would be pure noise. +#![allow(unsafe_op_in_unsafe_fn)] + +use core::ptr::{null, null_mut}; +use windows_sys::w; +use windows_sys::Win32::Foundation::{CloseHandle, GENERIC_READ, GENERIC_WRITE, INVALID_HANDLE_VALUE}; +use windows_sys::Win32::Storage::FileSystem::{CreateDirectoryW, CreateFileW, DeleteFileW, GetFileAttributesW, ReadFile, WriteFile, CREATE_ALWAYS, FILE_SHARE_READ, INVALID_FILE_ATTRIBUTES, OPEN_EXISTING}; +use windows_sys::Win32::System::Console::{GetStdHandle, STD_ERROR_HANDLE, STD_INPUT_HANDLE, STD_OUTPUT_HANDLE}; +use windows_sys::Win32::System::Environment::{GetCommandLineW, GetEnvironmentVariableW, SetEnvironmentVariableW}; +use windows_sys::Win32::System::LibraryLoader::GetModuleFileNameW; +use windows_sys::Win32::System::SystemInformation::{GetNativeSystemInfo, PROCESSOR_ARCHITECTURE_AMD64, PROCESSOR_ARCHITECTURE_ARM64, PROCESSOR_ARCHITECTURE_INTEL}; +use windows_sys::Win32::System::Threading::{ExitProcess, STARTF_USESTDHANDLES, CreateProcessW, STARTUPINFOW, PROCESS_INFORMATION, WaitForSingleObject, INFINITE, GetExitCodeProcess}; + +/// Exit code used when this wrapper panics, to distinguish from child +/// process failures. +#[cfg(not(test))] +const EXIT_CODE_PANIC: u32 = 0xFE; + +// A fixed-capacity UTF-16 buffer for building null-terminated wide strings +// to pass to Win32 APIs. All Win32-facing methods automatically null-terminate. +// +// Callers push ASCII (&[u8]) or wide (&WBuf) content; the buffer handles +// the ASCII-to-UTF-16 widening internally, keeping encoding concerns in +// one place. + +struct WBuf { + buf: [u16; N], + len: usize, +} + +impl WBuf { + fn new() -> Self { + Self { + buf: [0; N], + len: 0, + } + } + + /// Null-terminated pointer for Win32 APIs. + fn as_ptr(&mut self) -> *const u16 { + self.buf[self.len] = 0; + self.buf.as_ptr() + } + + /// Mutable null-terminated pointer (for CreateProcessW's lpCommandLine). + fn as_mut_ptr(&mut self) -> *mut u16 { + self.buf[self.len] = 0; + self.buf.as_mut_ptr() + } + + /// Append ASCII bytes, widening each byte to UTF-16. + fn push_ascii(&mut self, s: &[u8]) -> &mut Self { + for &b in s { + self.buf[self.len] = b as u16; + self.len += 1; + } + self + } + + /// Append the contents of another WBuf. + fn push_wbuf(&mut self, other: &WBuf) -> &mut Self { + self.buf[self.len..self.len + other.len].copy_from_slice(&other.buf[..other.len]); + self.len += other.len; + self + } + + /// Append raw UTF-16 content from a pointer until null terminator. + /// Used for appending the tail of GetCommandLineW. + unsafe fn push_ptr(&mut self, mut p: *const u16) -> &mut Self { + loop { + let c = *p; + if c == 0 { + break; + } + self.buf[self.len] = c; + self.len += 1; + p = p.add(1); + } + self + } + + /// Find the last path separator (\ or /) and truncate to it, + /// effectively navigating to the parent directory. + fn pop_path_component(&mut self) -> bool { + let mut i = self.len; + while i > 0 { + i -= 1; + if self.buf[i] == b'\\' as u16 || self.buf[i] == b'/' as u16 { + self.len = i; + return true; + } + } + false + } + + /// Check whether a file exists at "\". + unsafe fn file_exists_with(&mut self, suffix: &[u8]) -> bool { + let saved = self.len; + self.push_ascii(suffix); + let result = GetFileAttributesW(self.as_ptr()) != INVALID_FILE_ATTRIBUTES; + self.len = saved; + result + } +} + +/// Check if an environment variable equals an expected ASCII value. +/// Neither name nor val should include a null terminator. +unsafe fn env_eq(name: &[u8], val: &[u8]) -> bool { + let mut name_w = WBuf::<64>::new(); + name_w.push_ascii(name); + let mut buf = [0u16; 64]; + let n = GetEnvironmentVariableW(name_w.as_ptr(), buf.as_mut_ptr(), buf.len() as u32) as usize; + if n != val.len() { + return false; + } + for (i, &b) in val.iter().enumerate() { + if buf[i] != b as u16 { + return false; + } + } + true +} + +/// Get an environment variable's value into a WBuf. +/// Returns the number of characters written (0 if not set). +unsafe fn get_env(name: &[u8], dst: &mut WBuf) -> usize { + let mut name_w = WBuf::<64>::new(); + name_w.push_ascii(name); + let n = GetEnvironmentVariableW( + name_w.as_ptr(), + dst.buf.as_mut_ptr(), + dst.buf.len() as u32, + ) as usize; + dst.len = n; + n +} + +/// C runtime entry point for MinGW/MSVC. Called before main() would be. +/// We use #[no_main] so we define this directly. +#[unsafe(no_mangle)] +pub extern "C" fn mainCRTStartup() -> ! { + unsafe { main_impl() } +} + +unsafe fn main_impl() -> ! { + // Get our own exe path, e.g. "C:\Users\...\tailscale\tool\go.exe". + let mut exe = WBuf::<4096>::new(); + exe.len = GetModuleFileNameW(null_mut(), exe.buf.as_mut_ptr(), exe.buf.len() as u32) as usize; + if exe.len == 0 { + die(b"GetModuleFileNameW failed\n"); + } + + // Walk up directories from our exe location to find the repo root, + // identified by the presence of "go.toolchain.rev". + exe.pop_path_component(); // strip filename, e.g. "...\tool" + let repo_root = loop { + if !exe.file_exists_with(b"\\go.toolchain.rev") { + if !exe.pop_path_component() { + die(b"could not find go.toolchain.rev\n"); + } + continue; + } + break WBuf::<4096> { + buf: exe.buf, + len: exe.len, + }; + }; + + // Read the toolchain revision hash from go.toolchain.rev (or + // go.toolchain.next.rev if TS_GO_NEXT=1). + let mut rev_path = WBuf::<4096>::new(); + rev_path.push_wbuf(&repo_root); + if env_eq(b"TS_GO_NEXT", b"1") { + rev_path.push_ascii(b"\\go.toolchain.next.rev"); + } else { + rev_path.push_ascii(b"\\go.toolchain.rev"); + } + + let mut rev_buf = [0u8; 256]; + let rev = read_file_trimmed(&mut rev_path, &mut rev_buf); + + // Build the toolchain path. The rev is normally a git hash, and + // the toolchain lives at %USERPROFILE%\.cache\tsgo\. + // If the rev starts with "/" or "\" it's an absolute path to a + // local toolchain (used for testing). + let mut toolchain = WBuf::<4096>::new(); + if rev.first() == Some(&b'/') || rev.first() == Some(&b'\\') { + toolchain.push_ascii(rev); + } else { + if get_env(b"USERPROFILE", &mut toolchain) == 0 { + die(b"USERPROFILE not set\n"); + } + toolchain.push_ascii(b"\\.cache\\tsgo\\"); + toolchain.push_ascii(rev); + } + + // If the toolchain hasn't been downloaded yet (no ".extracted" marker), + // download it. For TS_USE_GOCROSS=1, fall back to PowerShell since + // that path also needs to build gocross. + if !toolchain.file_exists_with(b".extracted") { + if env_eq(b"TS_USE_GOCROSS", b"1") { + fallback_pwsh(&repo_root); + } + download_toolchain(&toolchain, rev); + } + + // Build the path to the real go.exe binary inside the toolchain, + // or to gocross.exe if TS_USE_GOCROSS=1. + let mut go_exe = WBuf::<4096>::new(); + if env_eq(b"TS_USE_GOCROSS", b"1") { + go_exe.push_wbuf(&repo_root).push_ascii(b"\\gocross.exe"); + } else { + go_exe.push_wbuf(&toolchain).push_ascii(b"\\bin\\go.exe"); + } + + // Unset GOROOT to avoid breaking builds that depend on our Go + // fork's patches (e.g. net/). The Go toolchain sets GOROOT + // internally from its own location. + SetEnvironmentVariableW(w!("GOROOT"), null()); + + // Build the new command line by replacing argv[0] with the real + // go.exe path. We take the raw command line from GetCommandLineW + // and pass the args portion through untouched — no parsing or + // re-escaping — so special characters like ^ and = survive intact. + let raw_cmd = GetCommandLineW(); + let args_tail = skip_argv0(raw_cmd); + + let mut cmd = WBuf::<32768>::new(); + cmd.push_ascii(b"\""); + cmd.push_wbuf(&go_exe); + cmd.push_ascii(b"\""); + cmd.push_ptr(args_tail); + + // Exec: create the child process, wait for it, and exit with its code. + let code = run_and_wait(go_exe.as_ptr(), &mut cmd, null()); + ExitProcess(code); +} + +/// Download the Go toolchain tarball from GitHub and extract it. +/// Uses curl.exe and tar.exe which ship with Windows 10+. +unsafe fn download_toolchain(toolchain: &WBuf<4096>, rev: &[u8]) { + stderr(b"# Downloading Go toolchain "); + stderr(rev); + stderr(b"\n"); + + // Create parent directories (%USERPROFILE%\.cache\tsgo). + // CreateDirectoryW is fine if the dir already exists. + let mut dir = WBuf::<4096>::new(); + get_env(b"USERPROFILE", &mut dir); + dir.push_ascii(b"\\.cache"); + CreateDirectoryW(dir.as_ptr(), null()); + dir.push_ascii(b"\\tsgo"); + CreateDirectoryW(dir.as_ptr(), null()); + + // Create the toolchain directory itself. + let mut tc_dir = WBuf::<4096>::new(); + tc_dir.push_wbuf(toolchain); + CreateDirectoryW(tc_dir.as_ptr(), null()); + + // Detect host architecture via GetNativeSystemInfo (gives real arch + // even from a WoW64 32-bit process). + let mut si = core::mem::zeroed(); + GetNativeSystemInfo(&mut si); + + let arch: &[u8] = match si.Anonymous.Anonymous.wProcessorArchitecture as u16 { + PROCESSOR_ARCHITECTURE_AMD64 => b"amd64", + PROCESSOR_ARCHITECTURE_ARM64 => b"arm64", + PROCESSOR_ARCHITECTURE_INTEL => b"386", + _ => die(b"unsupported architecture\n"), + }; + + // Build tarball path: .tar.gz + let mut tgz = WBuf::<4096>::new(); + tgz.push_wbuf(toolchain).push_ascii(b".tar.gz"); + + // Build URL: + // https://github.com/tailscale/go/releases/download/build-/windows-.tar.gz + let mut url = [0u8; 512]; + let mut u = 0; + for part in [ + b"https://github.com/tailscale/go/releases/download/build-" as &[u8], + rev, + b"/windows-", + arch, + b".tar.gz", + ] { + url[u..u + part.len()].copy_from_slice(part); + u += part.len(); + } + + // Run: curl.exe -fsSL -o + let mut cmd = WBuf::<32768>::new(); + cmd.push_ascii(b"curl.exe -fsSL -o \""); + cmd.push_wbuf(&tgz); + cmd.push_ascii(b"\" "); + cmd.push_ascii(&url[..u]); + + let code = run_and_wait(null(), &mut cmd, null()); + if code != 0 { + die(b"curl failed to download Go toolchain\n"); + } + + // Run: tar.exe --strip-components=1 -xf + // with working directory set to the toolchain dir. + let mut cmd = WBuf::<32768>::new(); + cmd.push_ascii(b"tar.exe --strip-components=1 -xf \""); + cmd.push_wbuf(&tgz); + cmd.push_ascii(b"\""); + + let code = run_and_wait(null(), &mut cmd, tc_dir.as_ptr()); + if code != 0 { + die(b"tar failed to extract Go toolchain\n"); + } + + // Write the .extracted marker file. + let mut marker = WBuf::<4096>::new(); + marker.push_wbuf(toolchain).push_ascii(b".extracted"); + let fh = CreateFileW(marker.as_ptr(), GENERIC_WRITE, 0, null(), CREATE_ALWAYS, 0, null_mut()); + if fh != INVALID_HANDLE_VALUE { + let mut written: u32 = 0; + WriteFile(fh, rev.as_ptr(), rev.len() as u32, &mut written, null_mut()); + CloseHandle(fh); + } + + // Clean up the tarball. + DeleteFileW(tgz.as_ptr()); +} + +/// Spawn a child process, wait for it, and return its exit code. +/// If app is null, CreateProcessW searches PATH using the command line. +/// If dir is null, the child inherits the current directory. +unsafe fn run_and_wait(app: *const u16, cmd: &mut WBuf<32768>, dir: *const u16) -> u32 { + let si = STARTUPINFOW { + cb: size_of::() as u32, + dwFlags: STARTF_USESTDHANDLES, + hStdInput: GetStdHandle(STD_INPUT_HANDLE), + hStdOutput: GetStdHandle(STD_OUTPUT_HANDLE), + hStdError: GetStdHandle(STD_ERROR_HANDLE), + ..Default::default() + }; + let mut pi = PROCESS_INFORMATION::default(); + + if CreateProcessW( + app, + cmd.as_mut_ptr(), + null(), + null(), + 1, // bInheritHandles = TRUE + 0, + null(), + dir, + &si, + &mut pi, + ) == 0 + { + die(b"CreateProcess failed\n"); + } + + WaitForSingleObject(pi.hProcess, INFINITE); + let mut code: u32 = 1; + GetExitCodeProcess(pi.hProcess, &mut code); + CloseHandle(pi.hProcess); + CloseHandle(pi.hThread); + code +} + +/// Fall back to PowerShell for the full bootstrap flow (downloading the +/// toolchain, optionally building gocross, and then running go): +/// pwsh -NoProfile -ExecutionPolicy Bypass "\tool\gocross\gocross-wrapper.ps1" +unsafe fn fallback_pwsh(repo_root: &WBuf<4096>) -> ! { + let raw_cmd = GetCommandLineW(); + let args_tail = skip_argv0(raw_cmd); + + let mut cmd = WBuf::<32768>::new(); + cmd.push_ascii(b"pwsh -NoProfile -ExecutionPolicy Bypass \""); + cmd.push_wbuf(repo_root); + cmd.push_ascii(b"\\tool\\gocross\\gocross-wrapper.ps1\""); + cmd.push_ptr(args_tail); + + // Pass null for lpApplicationName so CreateProcessW searches PATH for "pwsh". + let code = run_and_wait(null(), &mut cmd, null()); + ExitProcess(code); +} + +/// Read an entire file (expected to be small ASCII, e.g. a git hash) into buf, +/// and return the trimmed content as a byte slice. +unsafe fn read_file_trimmed<'a, const N: usize>( + path: &mut WBuf, + buf: &'a mut [u8], +) -> &'a [u8] { + let h = CreateFileW( + path.as_ptr(), + GENERIC_READ, + FILE_SHARE_READ, + null(), + OPEN_EXISTING, + 0, + null_mut(), + ); + if h == INVALID_HANDLE_VALUE { + die(b"cannot open go.toolchain.rev\n"); + } + let mut n: u32 = 0; + ReadFile(h, buf.as_mut_ptr(), buf.len() as u32, &mut n, null_mut()); + CloseHandle(h); + + let s = &buf[..n as usize]; + let start = s.iter().position(|b| !b.is_ascii_whitespace()).unwrap_or(s.len()); + let end = s.iter().rposition(|b| !b.is_ascii_whitespace()).map_or(start, |i| i + 1); + &s[start..end] +} + +/// Advance past argv[0] in a raw Windows command line string. +/// +/// Windows command lines are a single string; argv[0] may be quoted +/// (if the path contains spaces) or unquoted. +/// See https://learn.microsoft.com/en-us/cpp/c-language/parsing-c-command-line-arguments +unsafe fn skip_argv0(cmd: *const u16) -> *const u16 { + let mut p = cmd; + if *p == b'"' as u16 { + // Quoted argv[0]: advance past closing quote. + p = p.add(1); + while *p != 0 && *p != b'"' as u16 { + p = p.add(1); + } + if *p == b'"' as u16 { + p = p.add(1); + } + } else { + // Unquoted argv[0]: advance to first whitespace. + while *p != 0 && *p != b' ' as u16 && *p != b'\t' as u16 { + p = p.add(1); + } + } + // Return pointer to the rest (typically starts with a space before + // the first real argument, or is empty if there are no arguments). + p +} + +/// Write bytes to stderr. +unsafe fn stderr(msg: &[u8]) { + let h = GetStdHandle(STD_ERROR_HANDLE); + let mut n: u32 = 0; + WriteFile(h, msg.as_ptr(), msg.len() as u32, &mut n, null_mut()); +} + +/// Write an error message to stderr and terminate with exit code 1. +unsafe fn die(msg: &[u8]) -> ! { + stderr(b"tool/go: "); + stderr(msg); + ExitProcess(1); +} + +#[cfg(not(test))] +#[panic_handler] +fn panic(_: &core::panic::PanicInfo) -> ! { + unsafe { ExitProcess(EXIT_CODE_PANIC) } +} diff --git a/tool/listpkgs/listpkgs.go b/tool/listpkgs/listpkgs.go index 1c2dda257a7ca..b29db94b1f5c4 100644 --- a/tool/listpkgs/listpkgs.go +++ b/tool/listpkgs/listpkgs.go @@ -10,9 +10,12 @@ import ( "flag" "fmt" "go/build/constraint" + "io/fs" "log" "os" + "path/filepath" "slices" + "sort" "strings" "sync" @@ -27,11 +30,18 @@ var ( withoutTagsAnyStr = flag.String("without-tags-any", "", "if non-empty, a comma-separated list of build constraints to exclude (a package will be omitted if it contains any of these build tags)") shard = flag.String("shard", "", "if non-empty, a string of the form 'N/M' to only print packages in shard N of M (e.g. '1/3', '2/3', '3/3/' for different thirds of the list)") affectedByTag = flag.String("affected-by-tag", "", "if non-empty, only list packages whose test binary would be affected by the presence or absence of this build tag") + hasRootTests = flag.Bool("has-root-tests", false, "list packages (as ./relative/path) containing _test.go files that call tstest.RequireRoot") + hasGoGenerate = flag.Bool("has-go-generate", false, "only list packages that contain at least one //go:generate directive") ) func main() { flag.Parse() + if *hasRootTests { + printRootTestPkgs() + return + } + patterns := flag.Args() if len(patterns) == 0 { flag.Usage() @@ -112,6 +122,9 @@ Pkg: continue Pkg } } + if *hasGoGenerate && !pkgHasGoGenerate(pkg) { + continue Pkg + } matches++ if *shard != "" { @@ -281,3 +294,123 @@ func fileMentionsTag(filename, tag string) (bool, error) { } return tags[tag], nil } + +// pkgHasGoGenerate reports whether any source file in pkg contains a +// //go:generate directive. +func pkgHasGoGenerate(pkg *packages.Package) bool { + // Include IgnoredFiles so directives behind build constraints are still + // found; the caller can narrow by tag via -with-tags-all/-without-tags-any + // if they care. + all := slices.Concat(pkg.CompiledGoFiles, pkg.OtherFiles, pkg.IgnoredFiles) + for _, name := range all { + ok, err := fileHasGoGenerate(name) + if err != nil { + log.Printf("reading %s: %v", name, err) + continue + } + if ok { + return true + } + } + return false +} + +var ( + goGenerateMu sync.Mutex + goGenerate = map[string]bool{} // abs path -> whether file has //go:generate +) + +func fileHasGoGenerate(filename string) (bool, error) { + goGenerateMu.Lock() + v, ok := goGenerate[filename] + goGenerateMu.Unlock() + if ok { + return v, nil + } + + f, err := os.Open(filename) + if err != nil { + return false, err + } + defer f.Close() + + has := false + s := bufio.NewScanner(f) + for s.Scan() { + // go:generate directives must start at column 1 (no leading + // whitespace) to be recognized by the go tool. + if strings.HasPrefix(s.Text(), "//go:generate") { + has = true + break + } + } + if err := s.Err(); err != nil { + return false, fmt.Errorf("reading %s: %w", filename, err) + } + + goGenerateMu.Lock() + goGenerate[filename] = has + goGenerateMu.Unlock() + return has, nil +} + +// printRootTestPkgs walks the current directory tree looking for _test.go +// files that contain "tstest.RequireRoot" and prints the unique package +// directories as ./relative/path. +func printRootTestPkgs() { + root, err := os.Getwd() + if err != nil { + log.Fatal(err) + } + seen := map[string]bool{} + var dirs []string + filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return nil + } + name := d.Name() + if d.IsDir() { + // Skip hidden dirs and common non-Go dirs. + if strings.HasPrefix(name, ".") || name == "vendor" || name == "node_modules" { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(name, "_test.go") { + return nil + } + rel, err := filepath.Rel(root, path) + if err != nil { + return nil + } + dir := filepath.Dir(rel) + if seen[dir] { + return nil // already found a match in this dir + } + if fileContains(path, "tstest.RequireRoot") { + seen[dir] = true + dirs = append(dirs, dir) + } + return nil + }) + sort.Strings(dirs) + for _, d := range dirs { + fmt.Println("./" + filepath.ToSlash(d)) + } +} + +// fileContains reports whether the file at path contains the given substring. +func fileContains(path, substr string) bool { + f, err := os.Open(path) + if err != nil { + return false + } + defer f.Close() + s := bufio.NewScanner(f) + for s.Scan() { + if strings.Contains(s.Text(), substr) { + return true + } + } + return false +} diff --git a/tool/updateflakes/updateflakes.go b/tool/updateflakes/updateflakes.go new file mode 100644 index 0000000000000..e2a572d1278ba --- /dev/null +++ b/tool/updateflakes/updateflakes.go @@ -0,0 +1,264 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// updateflakes regenerates flakehashes.json, the file that records +// the Nix SRI hashes for the Go module vendor tree and the Tailscale +// Go toolchain tarball. +// +// The file is content-addressed: each block records the input +// fingerprint that produced its SRI, and updateflakes only +// regenerates a block when the current input differs from the +// recorded fingerprint. As a result, repeat runs with no input +// changes are no-ops. +// +// Run from the repo root: +// +// ./tool/go run ./tool/updateflakes +package main + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "flag" + "fmt" + "io/fs" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + + "golang.org/x/sync/errgroup" + "tailscale.com/cmd/nardump/nardump" +) + +const ( + hashesFile = "flakehashes.json" + goModFile = "go.mod" + goSumFile = "go.sum" + toolchainRevFile = "go.toolchain.rev" + flakeNixFile = "flake.nix" + shellNixFile = "shell.nix" + cacheBustPrefix = "# nix-direnv cache busting line:" +) + +// FlakeHashes is the on-disk schema of flakehashes.json. It is also +// consumed directly by flake.nix via builtins.fromJSON, so changes +// to the JSON shape must be coordinated with flake.nix. +type FlakeHashes struct { + Toolchain ToolchainHash `json:"toolchain"` + Vendor VendorHash `json:"vendor"` +} + +// ToolchainHash records the SRI of the Tailscale Go toolchain +// tarball. Rev is the value in go.toolchain.rev that produced SRI. +type ToolchainHash struct { + Rev string `json:"rev"` + SRI string `json:"sri"` +} + +// VendorHash records the SRI of `go mod vendor` output. GoModSum is a +// fingerprint of go.mod and go.sum that produced SRI. +type VendorHash struct { + GoModSum string `json:"goModSum"` + SRI string `json:"sri"` +} + +func main() { + flag.Parse() + if err := run(); err != nil { + log.Fatal(err) + } +} + +func run() error { + have, err := loadHashes() + if err != nil { + return err + } + want := have + + rev, err := readTrim(toolchainRevFile) + if err != nil { + return err + } + wantToolchain := have.Toolchain.Rev != rev || have.Toolchain.SRI == "" + + goModSum, err := goModFingerprint() + if err != nil { + return err + } + wantVendor := have.Vendor.GoModSum != goModSum || have.Vendor.SRI == "" + + var ( + newToolchain ToolchainHash + newVendor VendorHash + ) + var g errgroup.Group + if wantToolchain { + g.Go(func() error { + sri, err := hashToolchain(rev) + if err != nil { + return err + } + newToolchain = ToolchainHash{Rev: rev, SRI: sri} + return nil + }) + } + if wantVendor { + g.Go(func() error { + sri, err := hashVendor() + if err != nil { + return err + } + newVendor = VendorHash{GoModSum: goModSum, SRI: sri} + return nil + }) + } + if err := g.Wait(); err != nil { + return err + } + if wantToolchain { + want.Toolchain = newToolchain + } + if wantVendor { + want.Vendor = newVendor + } + + if want != have { + if err := writeHashes(want); err != nil { + return err + } + } + + // nix-direnv only watches the top-level nix files for changes, + // so when a referenced hash changes we must also tickle + // flake.nix and shell.nix to force re-evaluation. + for _, f := range []string{flakeNixFile, shellNixFile} { + if err := updateCacheBust(f, want.Vendor.SRI); err != nil { + return err + } + } + return nil +} + +func loadHashes() (FlakeHashes, error) { + var h FlakeHashes + data, err := os.ReadFile(hashesFile) + if errors.Is(err, fs.ErrNotExist) { + return h, nil + } + if err != nil { + return h, err + } + if err := json.Unmarshal(data, &h); err != nil { + return h, fmt.Errorf("parse %s: %w", hashesFile, err) + } + return h, nil +} + +func writeHashes(h FlakeHashes) error { + b, err := json.MarshalIndent(h, "", " ") + if err != nil { + return err + } + b = append(b, '\n') + return os.WriteFile(hashesFile, b, 0644) +} + +func readTrim(path string) (string, error) { + b, err := os.ReadFile(path) + if err != nil { + return "", err + } + return strings.TrimSpace(string(b)), nil +} + +// goModFingerprint returns a content fingerprint of go.mod and go.sum +// that changes whenever either file changes. +func goModFingerprint() (string, error) { + h := sha256.New() + for _, f := range []string{goModFile, goSumFile} { + b, err := os.ReadFile(f) + if err != nil { + return "", err + } + fmt.Fprintf(h, "%s %d\n", f, len(b)) + h.Write(b) + } + return "sha256-" + base64.StdEncoding.EncodeToString(h.Sum(nil)), nil +} + +func hashVendor() (string, error) { + out, err := os.MkdirTemp("", "nar-vendor-") + if err != nil { + return "", err + } + // `go mod vendor -o` requires the destination to not already exist. + if err := os.Remove(out); err != nil { + return "", err + } + defer os.RemoveAll(out) + + cmd := exec.Command("./tool/go", "mod", "vendor", "-o", out) + cmd.Env = append(os.Environ(), "GOWORK=off") + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("go mod vendor: %w", err) + } + return nardump.SRI(os.DirFS(out)) +} + +func hashToolchain(rev string) (string, error) { + out, err := os.MkdirTemp("", "nar-toolchain-") + if err != nil { + return "", err + } + defer os.RemoveAll(out) + + url := fmt.Sprintf("https://github.com/tailscale/go/archive/%s.tar.gz", rev) + resp, err := http.Get(url) + if err != nil { + return "", fmt.Errorf("fetching %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("fetching %s: %s", url, resp.Status) + } + + tar := exec.Command("tar", "-xz", "-C", out) + tar.Stdin = resp.Body + tar.Stderr = os.Stderr + if err := tar.Run(); err != nil { + return "", fmt.Errorf("extracting toolchain tarball: %w", err) + } + return nardump.SRI(os.DirFS(filepath.Join(out, "go-"+rev))) +} + +// updateCacheBust rewrites the "# nix-direnv cache busting line" +// in path to embed sri so nix-direnv re-evaluates when the SRI +// changes. The line lives at end of file, so walk in reverse. +func updateCacheBust(path, sri string) error { + b, err := os.ReadFile(path) + if err != nil { + return err + } + want := []byte(cacheBustPrefix + " " + sri) + lines := bytes.Split(b, []byte("\n")) + for i := len(lines) - 1; i >= 0; i-- { + line := lines[i] + if !bytes.HasPrefix(line, []byte(cacheBustPrefix)) { + continue + } + if bytes.Equal(line, want) { + return nil + } + lines[i] = want + return os.WriteFile(path, bytes.Join(lines, []byte("\n")), 0644) + } + return fmt.Errorf("%s: missing %q line", path, cacheBustPrefix) +} diff --git a/tsconsensus/authorization_test.go b/tsconsensus/authorization_test.go index 0f7a4e5958595..72c920972ebb5 100644 --- a/tsconsensus/authorization_test.go +++ b/tsconsensus/authorization_test.go @@ -68,14 +68,17 @@ func authForPeers(self *ipnstate.PeerStatus, peers []*ipnstate.PeerStatus) *auth func TestAuthRefreshErrorsNotRunning(t *testing.T) { tests := []struct { + name string in *ipnstate.Status expected string }{ { + name: "no-status", in: nil, expected: "no status", }, { + name: "ts-server-not-running", in: &ipnstate.Status{ BackendState: "NeedsMachineAuth", }, @@ -84,7 +87,7 @@ func TestAuthRefreshErrorsNotRunning(t *testing.T) { } for _, tt := range tests { - t.Run(tt.expected, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { ctx := t.Context() a := authForStatus(tt.in) err := a.Refresh(ctx) @@ -127,22 +130,22 @@ func TestAuthAllowsHost(t *testing.T) { expected bool }{ { - name: "tagged with different tag", + name: "tagged-different-tag", peerStatus: peers[0], expected: false, }, { - name: "not tagged", + name: "not-tagged", peerStatus: peers[1], expected: false, }, { - name: "tags includes testTag", + name: "tags-include-testTag", peerStatus: peers[2], expected: true, }, { - name: "only tag is testTag", + name: "only-testTag", peerStatus: peers[3], expected: true, }, @@ -201,12 +204,12 @@ func TestAuthSelfAllowed(t *testing.T) { expected bool }{ { - name: "self has different tag", + name: "self-different-tag", in: []string{"woo"}, expected: false, }, { - name: "selfs tags include testTag", + name: "self-tags-include-testTag", in: []string{"woo", testTag}, expected: true, }, diff --git a/tsconsensus/monitor.go b/tsconsensus/monitor.go index cc5ac812c49d9..bf7410d0df332 100644 --- a/tsconsensus/monitor.go +++ b/tsconsensus/monitor.go @@ -12,7 +12,6 @@ import ( "net/http" "slices" - "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/tsnet" "tailscale.com/util/dnsname" @@ -85,7 +84,7 @@ func (m *monitor) handleSummaryStatus(w http.ResponseWriter, r *http.Request) { lines = append(lines, fmt.Sprintf("%s\t\t%d\t%d\t%t", name, p.RxBytes, p.TxBytes, p.Active)) } } - _, err = w.Write([]byte(fmt.Sprintf("RaftState: %s\n", s.RaftState))) + _, err = w.Write(fmt.Appendf(nil, "RaftState: %s\n", s.RaftState)) if err != nil { log.Printf("monitor: error writing status: %v", err) return @@ -93,7 +92,7 @@ func (m *monitor) handleSummaryStatus(w http.ResponseWriter, r *http.Request) { slices.Sort(lines) for _, ln := range lines { - _, err = w.Write([]byte(fmt.Sprintf("%s\n", ln))) + _, err = w.Write(fmt.Appendf(nil, "%s\n", ln)) if err != nil { log.Printf("monitor: error writing status: %v", err) return @@ -108,24 +107,16 @@ func (m *monitor) handleNetmap(w http.ResponseWriter, r *http.Request) { http.Error(w, "", http.StatusInternalServerError) return } - watcher, err := lc.WatchIPNBus(r.Context(), ipn.NotifyInitialNetMap) + st, err := lc.Status(r.Context()) if err != nil { - log.Printf("monitor: error WatchIPNBus: %v", err) - http.Error(w, "", http.StatusInternalServerError) - return - } - defer watcher.Close() - - n, err := watcher.Next() - if err != nil { - log.Printf("monitor: error watcher.Next: %v", err) + log.Printf("monitor: error fetching status: %v", err) http.Error(w, "", http.StatusInternalServerError) return } encoder := json.NewEncoder(w) encoder.SetIndent("", "\t") - if err := encoder.Encode(n); err != nil { - log.Printf("monitor: error encoding netmap: %v", err) + if err := encoder.Encode(st); err != nil { + log.Printf("monitor: error encoding status: %v", err) return } } diff --git a/tsconsensus/tsconsensus_test.go b/tsconsensus/tsconsensus_test.go index 8897db119c467..3236ef680a8e9 100644 --- a/tsconsensus/tsconsensus_test.go +++ b/tsconsensus/tsconsensus_test.go @@ -296,7 +296,7 @@ func startNodesAndWaitForPeerStatus(t testing.TB, ctx context.Context, clusterTa keysToTag := make([]key.NodePublic, nNodes) localClients := make([]*tailscale.LocalClient, nNodes) control, controlURL := startControl(t) - for i := 0; i < nNodes; i++ { + for i := range nNodes { ts, key, _ := startNode(t, ctx, controlURL, fmt.Sprintf("node %d", i)) ps[i] = &participant{ts: ts, key: key} keysToTag[i] = key @@ -353,7 +353,7 @@ func createConsensusCluster(t testing.TB, ctx context.Context, clusterTag string } fxRaftConfigContainsAll := func() bool { - for i := 0; i < len(participants); i++ { + for i := range participants { fut := participants[i].c.raft.GetConfiguration() err = fut.Error() if err != nil { @@ -618,8 +618,8 @@ func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) { } isNetErr := func(err error) bool { - var netErr net.Error - return errors.As(err, &netErr) + _, ok := errors.AsType[net.Error](err) + return ok } err := getErrorFromTryingToSend(untaggedNode) diff --git a/tsd/tsd.go b/tsd/tsd.go index 9d79334d68e2b..615c9c0e741c7 100644 --- a/tsd/tsd.go +++ b/tsd/tsd.go @@ -18,6 +18,7 @@ package tsd import ( + "crypto/x509" "fmt" "reflect" @@ -63,6 +64,12 @@ type System struct { PolicyClient SubSystem[policyclient.Client] HealthTracker SubSystem[*health.Tracker] + // ExtraRootCAs, if non-nil, specifies additional trusted root CAs + // beyond the system roots. On Android, this includes user-installed + // CA certificates that Go's crypto/x509 does not see. + // It is plumbed through to tlsdial.Config via tls.Config.RootCAs. + ExtraRootCAs *x509.CertPool + // InitialConfig is initial server config, if any. // It is nil if the node is not in declarative mode. // This value is never updated after startup. @@ -226,8 +233,7 @@ func (p *SubSystem[T]) Set(v T) { return } - var z *T - panic(fmt.Sprintf("%v is already set", reflect.TypeOf(z).Elem().String())) + panic(fmt.Sprintf("%v is already set", reflect.TypeFor[T]().String())) } p.v = v p.set = true @@ -236,8 +242,7 @@ func (p *SubSystem[T]) Set(v T) { // Get returns the value of p, panicking if it hasn't been set. func (p *SubSystem[T]) Get() T { if !p.set { - var z *T - panic(fmt.Sprintf("%v is not set", reflect.TypeOf(z).Elem().String())) + panic(fmt.Sprintf("%v is not set", reflect.TypeFor[T]().String())) } return p.v } diff --git a/tsnet/README.md b/tsnet/README.md new file mode 100644 index 0000000000000..f9a96af006569 --- /dev/null +++ b/tsnet/README.md @@ -0,0 +1,109 @@ + + +# tsnet + +[![Go Reference](https://pkg.go.dev/badge/tailscale.com/tsnet.svg)](https://pkg.go.dev/tailscale.com/tsnet) + +Package tsnet embeds a Tailscale node directly into a Go program, allowing it to join a tailnet and accept or dial connections without running a separate tailscaled daemon or requiring any system-level configuration. + +## Overview + +Normally, Tailscale runs as a background system service (tailscaled) that manages a virtual network interface for the whole machine. tsnet takes a different approach: it runs a fully self-contained Tailscale node inside your process using a userspace TCP/IP stack (gVisor). This means: + + - No root privileges required. + - No system daemons to install or manage. + - Multiple independent Tailscale nodes can run within a single binary. + - The node's [Tailscale identity](https://tailscale.com/docs/concepts/tailscale-identity) and state are stored in a directory you control. + +The core type is [Server](https://pkg.go.dev/tailscale.com/tsnet#Server), which represents one embedded Tailscale node. Calling [Server.Listen](https://pkg.go.dev/tailscale.com/tsnet#Server.Listen) or [Server.Dial](https://pkg.go.dev/tailscale.com/tsnet#Server.Dial) routes traffic exclusively over the tailnet. The standard library's [net.Listener](https://pkg.go.dev/net#Listener) and [net.Conn](https://pkg.go.dev/net#Conn) interfaces are returned, so any existing Go HTTP server, gRPC server, or other net-based code works without modification. + +## Usage + + import "tailscale.com/tsnet" + + s := &tsnet.Server{ + Hostname: "my-service", + AuthKey: os.Getenv("TS_AUTHKEY"), + } + defer s.Close() + + ln, err := s.Listen("tcp", ":80") + if err != nil { + log.Fatal(err) + } + log.Fatal(http.Serve(ln, myHandler)) + +On first run, if no [Server.AuthKey](https://pkg.go.dev/tailscale.com/tsnet#Server.AuthKey) is provided and the node is not already enrolled, the server logs an authentication URL. Open it in a browser to add the node to your tailnet. + +## Authentication + +A [Server](https://pkg.go.dev/tailscale.com/tsnet#Server) authenticates using, in order of precedence: + + 1. [Server.AuthKey](https://pkg.go.dev/tailscale.com/tsnet#Server.AuthKey). + + 2. The TS\_AUTHKEY environment variable. + + 3. The TS\_AUTH\_KEY environment variable. + + 4. An OAuth client secret ([Server.ClientSecret](https://pkg.go.dev/tailscale.com/tsnet#Server.ClientSecret) or TS\_CLIENT\_SECRET), used to mint an auth key. + + 5. Workload identity federation ([Server.ClientID](https://pkg.go.dev/tailscale.com/tsnet#Server.ClientID) plus [Server.IDToken](https://pkg.go.dev/tailscale.com/tsnet#Server.IDToken) or [Server.Audience](https://pkg.go.dev/tailscale.com/tsnet#Server.Audience)). Available only if the program imports the feature: + + import \_ "tailscale.com/feature/identityfederation" + + The feature is not linked by default to keep the AWS SDK and other cloud-provider dependencies out of programs that don't use workload identity federation. + + 6. An interactive login URL printed to [Server.UserLogf](https://pkg.go.dev/tailscale.com/tsnet#Server.UserLogf). + +If the node is already enrolled (state found in [Server.Store](https://pkg.go.dev/tailscale.com/tsnet#Server.Store)), the auth key is ignored unless TSNET\_FORCE\_LOGIN=1 is set. + +## Identifying callers + +Use the WhoIs method on the client returned by [Server.LocalClient](https://pkg.go.dev/tailscale.com/tsnet#Server.LocalClient) to identify who is making a request: + + lc, _ := srv.LocalClient() + http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + who, err := lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + fmt.Fprintf(w, "Hello, %s!", who.UserProfile.LoginName) + })) + +## Tailscale Funnel + +[Server.ListenFunnel](https://pkg.go.dev/tailscale.com/tsnet#Server.ListenFunnel) exposes your service on the public internet. [Tailscale Funnel](https://tailscale.com/docs/features/tailscale-funnel) currently supports TCP on ports 443, 8443, and 10000. HTTPS must be enabled in the Tailscale admin console. + + ln, err := srv.ListenFunnel("tcp", ":443") + // ln is a TLS listener; connections can come from anywhere on the + // internet as well as from your tailnet. + + // To restrict to public traffic only: + ln, err = srv.ListenFunnel("tcp", ":443", tsnet.FunnelOnly()) + +## Tailscale Services + +[Server.ListenService](https://pkg.go.dev/tailscale.com/tsnet#Server.ListenService) advertises the node as a host for a named [Tailscale Service](https://tailscale.com/docs/features/tailscale-services). The node must use a tag-based identity. To advertise multiple ports, call ListenService once per port. + + srv.AdvertiseTags = []string{"tag:myservice"} + + ln, err := srv.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ + HTTPS: true, + Port: 443, + }) + log.Printf("Listening on https://%s", ln.FQDN) + +## Running multiple nodes in one process + +Each [Server](https://pkg.go.dev/tailscale.com/tsnet#Server) instance is an independent node. Give each a unique [Server.Dir](https://pkg.go.dev/tailscale.com/tsnet#Server.Dir) and [Server.Hostname](https://pkg.go.dev/tailscale.com/tsnet#Server.Hostname): + + for _, name := range []string{"frontend", "backend"} { + srv := &tsnet.Server{ + Hostname: name, + Dir: filepath.Join(baseDir, name), + AuthKey: os.Getenv("TS_AUTHKEY"), + Ephemeral: true, + } + srv.Start() + } diff --git a/tsnet/depaware.txt b/tsnet/depaware.txt index cb6b6996b7a87..a4eed2a13a338 100644 --- a/tsnet/depaware.txt +++ b/tsnet/depaware.txt @@ -6,77 +6,6 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) W đŸ’Ŗ github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W đŸ’Ŗ github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy - github.com/aws/aws-sdk-go-v2/aws from github.com/aws/aws-sdk-go-v2/aws/defaults+ - github.com/aws/aws-sdk-go-v2/aws/defaults from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/aws/middleware from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/aws-sdk-go-v2/aws/protocol/query from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/aws/protocol/restjson from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/aws/protocol/xml from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/aws/ratelimit from github.com/aws/aws-sdk-go-v2/aws/retry - github.com/aws/aws-sdk-go-v2/aws/retry from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client+ - github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 from github.com/aws/aws-sdk-go-v2/aws/signer/v4 - github.com/aws/aws-sdk-go-v2/aws/signer/v4 from github.com/aws/aws-sdk-go-v2/internal/auth/smithy+ - github.com/aws/aws-sdk-go-v2/aws/transport/http from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/config from tailscale.com/wif - github.com/aws/aws-sdk-go-v2/credentials from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/endpointcreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds - github.com/aws/aws-sdk-go-v2/credentials/processcreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/ssocreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/stscreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/feature/ec2/imds from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/feature/ec2/imds/internal/config from github.com/aws/aws-sdk-go-v2/feature/ec2/imds - github.com/aws/aws-sdk-go-v2/internal/auth from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ - github.com/aws/aws-sdk-go-v2/internal/auth/smithy from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/configsources from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/context from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/aws-sdk-go-v2/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 from github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints+ - github.com/aws/aws-sdk-go-v2/internal/ini from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/internal/middleware from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/rand from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/aws-sdk-go-v2/internal/sdk from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/aws-sdk-go-v2/internal/sdkio from github.com/aws/aws-sdk-go-v2/credentials/processcreds - github.com/aws/aws-sdk-go-v2/internal/shareddefaults from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/internal/strings from github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 - github.com/aws/aws-sdk-go-v2/internal/sync/singleflight from github.com/aws/aws-sdk-go-v2/aws - github.com/aws/aws-sdk-go-v2/internal/timeconv from github.com/aws/aws-sdk-go-v2/aws/retry - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/sso from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso - github.com/aws/aws-sdk-go-v2/service/sso/types from github.com/aws/aws-sdk-go-v2/service/sso - github.com/aws/aws-sdk-go-v2/service/ssooidc from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/ssooidc/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/aws-sdk-go-v2/service/ssooidc/types from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/aws-sdk-go-v2/service/sts from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/sts/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/sts/types from github.com/aws/aws-sdk-go-v2/credentials/stscreds+ - github.com/aws/smithy-go from github.com/aws/aws-sdk-go-v2/aws/protocol/restjson+ - github.com/aws/smithy-go/auth from github.com/aws/aws-sdk-go-v2/internal/auth+ - github.com/aws/smithy-go/auth/bearer from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/context from github.com/aws/smithy-go/auth/bearer - github.com/aws/smithy-go/document from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/encoding from github.com/aws/smithy-go/encoding/json+ - github.com/aws/smithy-go/encoding/httpbinding from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ - github.com/aws/smithy-go/encoding/json from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/smithy-go/encoding/xml from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/smithy-go/endpoints from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/endpoints/private/rulesfn from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/smithy-go/internal/sync/singleflight from github.com/aws/smithy-go/auth/bearer - github.com/aws/smithy-go/io from github.com/aws/aws-sdk-go-v2/feature/ec2/imds+ - github.com/aws/smithy-go/logging from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/smithy-go/middleware from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/private/requestcompression from github.com/aws/aws-sdk-go-v2/config - github.com/aws/smithy-go/ptr from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/rand from github.com/aws/aws-sdk-go-v2/aws/middleware - github.com/aws/smithy-go/time from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/tracing from github.com/aws/aws-sdk-go-v2/aws/middleware+ - github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http LDW github.com/coder/websocket from tailscale.com/util/eventbus LDW github.com/coder/websocket/internal/errd from github.com/coder/websocket LDW github.com/coder/websocket/internal/util from github.com/coder/websocket @@ -105,7 +34,6 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) L đŸ’Ŗ github.com/godbus/dbus/v5 from tailscale.com/net/dns github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/transport/tcp - DI github.com/google/uuid from github.com/prometheus-community/pro-bing github.com/hdevalence/ed25519consensus from tailscale.com/tka github.com/huin/goupnp from github.com/huin/goupnp/dcps/internetgateway2+ github.com/huin/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper @@ -128,9 +56,8 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) LA đŸ’Ŗ github.com/mdlayher/socket from github.com/mdlayher/netlink+ LDW đŸ’Ŗ github.com/mitchellh/go-ps from tailscale.com/safesocket github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal - DI github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack L đŸ’Ŗ github.com/safchain/ethtool from tailscale.com/net/netkernelconf - W đŸ’Ŗ github.com/tailscale/certstore from tailscale.com/control/controlclient + DW đŸ’Ŗ github.com/tailscale/certstore from tailscale.com/control/controlclient W đŸ’Ŗ github.com/tailscale/go-winio from tailscale.com/safesocket W đŸ’Ŗ github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio W đŸ’Ŗ github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio @@ -219,11 +146,9 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) tailscale.com/feature/buildfeatures from tailscale.com/wgengine/magicsock+ tailscale.com/feature/c2n from tailscale.com/tsnet tailscale.com/feature/condlite/expvar from tailscale.com/wgengine/magicsock - tailscale.com/feature/condregister/identityfederation from tailscale.com/tsnet tailscale.com/feature/condregister/oauthkey from tailscale.com/tsnet tailscale.com/feature/condregister/portmapper from tailscale.com/tsnet tailscale.com/feature/condregister/useproxy from tailscale.com/tsnet - tailscale.com/feature/identityfederation from tailscale.com/feature/condregister/identityfederation tailscale.com/feature/oauthkey from tailscale.com/feature/condregister/oauthkey tailscale.com/feature/portmapper from tailscale.com/feature/condregister/portmapper tailscale.com/feature/syspolicy from tailscale.com/logpolicy @@ -235,7 +160,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) tailscale.com/ipn from tailscale.com/client/local+ tailscale.com/ipn/conffile from tailscale.com/ipn/ipnlocal+ đŸ’Ŗ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnext+ - tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal+ tailscale.com/ipn/ipnlocal from tailscale.com/ipn/localapi+ tailscale.com/ipn/ipnlocal/netmapcache from tailscale.com/ipn/ipnlocal tailscale.com/ipn/ipnstate from tailscale.com/client/local+ @@ -304,12 +229,13 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) tailscale.com/tstime from tailscale.com/control/controlclient+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ tailscale.com/tstime/rate from tailscale.com/wgengine/filter - LDW tailscale.com/tsweb from tailscale.com/util/eventbus + LDW tailscale.com/tsweb from tailscale.com/util/eventbus+ tailscale.com/tsweb/varz from tailscale.com/tsweb+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/bools from tailscale.com/tsnet+ tailscale.com/types/dnstype from tailscale.com/client/local+ tailscale.com/types/empty from tailscale.com/ipn+ + tailscale.com/types/events from tailscale.com/control/controlclient+ tailscale.com/types/ipproto from tailscale.com/ipn+ tailscale.com/types/key from tailscale.com/client/local+ tailscale.com/types/lazy from tailscale.com/hostinfo+ @@ -323,12 +249,12 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) tailscale.com/types/opt from tailscale.com/control/controlknobs+ tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/preftype from tailscale.com/ipn+ - tailscale.com/types/ptr from tailscale.com/control/controlclient+ tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/control/controlclient+ tailscale.com/types/tkatype from tailscale.com/client/local+ tailscale.com/types/views from tailscale.com/appc+ tailscale.com/util/backoff from tailscale.com/control/controlclient+ + tailscale.com/util/bufiox from tailscale.com/types/key tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/appc+ @@ -393,12 +319,10 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal đŸ’Ŗ tailscale.com/wgengine/wgint from tailscale.com/wgengine+ tailscale.com/wgengine/wglog from tailscale.com/wgengine - tailscale.com/wif from tailscale.com/feature/identityfederation golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ - LD golang.org/x/crypto/blowfish from golang.org/x/crypto/ssh/internal/bcrypt_pbkdf - golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305+ + golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305 golang.org/x/crypto/chacha20poly1305 from github.com/tailscale/wireguard-go/device+ golang.org/x/crypto/curve25519 from github.com/tailscale/wireguard-go/device+ golang.org/x/crypto/hkdf from tailscale.com/control/controlbase @@ -408,24 +332,22 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/poly1305 from github.com/tailscale/wireguard-go/device golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - LD golang.org/x/crypto/ssh from tailscale.com/ipn/ipnlocal - LD golang.org/x/crypto/ssh/internal/bcrypt_pbkdf from golang.org/x/crypto/ssh golang.org/x/exp/constraints from tailscale.com/tsweb/varz+ golang.org/x/exp/maps from tailscale.com/ipn/store/mem+ golang.org/x/net/bpf from github.com/mdlayher/netlink+ golang.org/x/net/dns/dnsmessage from tailscale.com/appc+ golang.org/x/net/http/httpguts from tailscale.com/ipn/ipnlocal golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy - golang.org/x/net/icmp from github.com/prometheus-community/pro-bing+ + golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts+ golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ LDW golang.org/x/net/internal/socks from golang.org/x/net/proxy - golang.org/x/net/ipv4 from github.com/prometheus-community/pro-bing+ - golang.org/x/net/ipv6 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ + golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ LDW golang.org/x/net/proxy from tailscale.com/net/netns DI golang.org/x/net/route from tailscale.com/net/netmon+ - golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials+ + golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials golang.org/x/oauth2/clientcredentials from tailscale.com/feature/oauthkey golang.org/x/oauth2/internal from golang.org/x/oauth2+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ @@ -470,7 +392,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) crypto/aes from crypto/tls+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ - crypto/dsa from crypto/x509+ + crypto/dsa from crypto/x509 crypto/ecdh from crypto/ecdsa+ crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ @@ -519,21 +441,20 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) crypto/internal/randutil from crypto/internal/rand crypto/internal/sysrand from crypto/internal/fips140/drbg crypto/md5 from crypto/tls+ - crypto/mlkem from golang.org/x/crypto/ssh+ + crypto/mlkem from crypto/hpke+ crypto/rand from crypto/ed25519+ - crypto/rc4 from crypto/tls+ + crypto/rc4 from crypto/tls crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ crypto/sha3 from crypto/internal/fips140hash+ crypto/sha512 from crypto/ecdsa+ crypto/subtle from crypto/cipher+ - crypto/tls from github.com/prometheus-community/pro-bing+ + crypto/tls from net/http+ crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ DI crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ - DI database/sql/driver from github.com/google/uuid W debug/dwarf from debug/pe W debug/pe from github.com/dblohm7/wingoes/pe embed from github.com/tailscale/web-client-prebuilt+ @@ -622,7 +543,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptrace from github.com/prometheus-community/pro-bing+ + net/http/httptrace from net/http+ net/http/httputil from tailscale.com/client/web+ net/http/internal from net/http+ net/http/internal/ascii from net/http+ @@ -636,7 +557,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) os/user from github.com/godbus/dbus/v5+ path from debug/dwarf+ path/filepath from crypto/x509+ - reflect from database/sql/driver+ + reflect from encoding/asn1+ regexp from github.com/huin/goupnp/httpu+ regexp/syntax from regexp runtime from crypto/internal/fips140+ diff --git a/tsnet/example/tshello/README.md b/tsnet/example/tshello/README.md new file mode 100644 index 0000000000000..5d9d81829fd97 --- /dev/null +++ b/tsnet/example/tshello/README.md @@ -0,0 +1,5 @@ + + +# tshello + +The tshello server demonstrates how to use Tailscale as a library. diff --git a/tsnet/example/tsnet-funnel/README.md b/tsnet/example/tsnet-funnel/README.md new file mode 100644 index 0000000000000..2b3031bed66c8 --- /dev/null +++ b/tsnet/example/tsnet-funnel/README.md @@ -0,0 +1,9 @@ + + +# tsnet-funnel + +The tsnet-funnel server demonstrates how to use tsnet with Funnel. + +To use it, generate an auth key from the Tailscale admin panel and run the demo with the key: + + TS_AUTHKEY= go run tsnet-funnel.go diff --git a/tsnet/example/tsnet-http-client/README.md b/tsnet/example/tsnet-http-client/README.md new file mode 100644 index 0000000000000..24aba97c8bb18 --- /dev/null +++ b/tsnet/example/tsnet-http-client/README.md @@ -0,0 +1,5 @@ + + +# tsnet-http-client + +The tshello server demonstrates how to use Tailscale as a library. diff --git a/tsnet/example/tsnet-services/README.md b/tsnet/example/tsnet-services/README.md new file mode 100644 index 0000000000000..18bc072d782f6 --- /dev/null +++ b/tsnet/example/tsnet-services/README.md @@ -0,0 +1,32 @@ + + +# tsnet-services + +The tsnet-services example demonstrates how to use tsnet with Services. + +To run this example yourself: + + 1. Add access controls which (i) define a new ACL tag, (ii) allow the demo node to host the Service, and (iii) allow peers on the tailnet to reach the Service. A sample ACL policy is provided below. + 2. [Generate an auth key](https://tailscale.com/kb/1085/auth-keys#generate-an-auth-key) using the Tailscale admin panel. When doing so, add your new tag to your key (Service hosts must be tagged nodes). + 3. [Define a Service](https://tailscale.com/kb/1552/tailscale-services#step-1-define-a-tailscale-service). For the purposes of this demo, it must be defined to listen on TCP port 443. Note that you only need to follow Step 1 in the linked document. + 4. Run the demo on the command line (step 4 command shown below). + +Command for step 4: + + TS_AUTHKEY= go run tsnet-services.go -service + +The following is a sample ACL policy for step 1: + + "tagOwners": { + "tag:tsnet-demo-host": ["autogroup:member"], + }, + "autoApprovers": { + "services": { + "svc:tsnet-demo": ["tag:tsnet-demo-host"], + }, + }, + "grants": [ + "src": ["*"], + "dst": ["svc:tsnet-demo"], + "ip": ["*"], + ], diff --git a/tsnet/example/tsnet-services/tsnet-services.go b/tsnet/example/tsnet-services/tsnet-services.go index d72fd68fd412a..4604e8d3fbbce 100644 --- a/tsnet/example/tsnet-services/tsnet-services.go +++ b/tsnet/example/tsnet-services/tsnet-services.go @@ -8,17 +8,16 @@ // 1. Add access controls which (i) define a new ACL tag, (ii) allow the demo // node to host the Service, and (iii) allow peers on the tailnet to reach // the Service. A sample ACL policy is provided below. -// // 2. [Generate an auth key] using the Tailscale admin panel. When doing so, add // your new tag to your key (Service hosts must be tagged nodes). -// // 3. [Define a Service]. For the purposes of this demo, it must be defined to // listen on TCP port 443. Note that you only need to follow Step 1 in the // linked document. +// 4. Run the demo on the command line (step 4 command shown below). // -// 4. Run the demo on the command line: +// Command for step 4: // -// TS_AUTHKEY= go run tsnet-services.go -service +// TS_AUTHKEY= go run tsnet-services.go -service // // The following is a sample ACL policy for step 1: // diff --git a/tsnet/example/web-client/README.md b/tsnet/example/web-client/README.md new file mode 100644 index 0000000000000..6b4c42235983a --- /dev/null +++ b/tsnet/example/web-client/README.md @@ -0,0 +1,5 @@ + + +# web-client + +The web-client command demonstrates serving the Tailscale web client over tsnet. diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 776854e227926..eb72d28d3f3e0 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -1,7 +1,139 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -// Package tsnet provides Tailscale as a library. +// Package tsnet embeds a Tailscale node directly into a Go program, +// allowing it to join a tailnet and accept or dial connections without +// running a separate tailscaled daemon or requiring any system-level +// configuration. +// +// # Overview +// +// Normally, Tailscale runs as a background system service (tailscaled) +// that manages a virtual network interface for the whole machine. tsnet +// takes a different approach: it runs a fully self-contained Tailscale +// node inside your process using a userspace TCP/IP stack (gVisor). +// This means: +// +// - No root privileges required. +// - No system daemons to install or manage. +// - Multiple independent Tailscale nodes can run within a single binary. +// - The node's [Tailscale identity] and state are stored in a directory you control. +// +// The core type is [Server], which represents one embedded Tailscale +// node. Calling [Server.Listen] or [Server.Dial] routes traffic +// exclusively over the tailnet. The standard library's [net.Listener] +// and [net.Conn] interfaces are returned, so any existing Go HTTP +// server, gRPC server, or other net-based code works without +// modification. +// +// # Usage +// +// import "tailscale.com/tsnet" +// +// s := &tsnet.Server{ +// Hostname: "my-service", +// AuthKey: os.Getenv("TS_AUTHKEY"), +// } +// defer s.Close() +// +// ln, err := s.Listen("tcp", ":80") +// if err != nil { +// log.Fatal(err) +// } +// log.Fatal(http.Serve(ln, myHandler)) +// +// On first run, if no [Server.AuthKey] is provided and the node is not +// already enrolled, the server logs an authentication URL. Open it in a +// browser to add the node to your tailnet. +// +// # Authentication +// +// A [Server] authenticates using, in order of precedence: +// +// 1. [Server.AuthKey]. +// +// 2. The TS_AUTHKEY environment variable. +// +// 3. The TS_AUTH_KEY environment variable. +// +// 4. An OAuth client secret ([Server.ClientSecret] or TS_CLIENT_SECRET), +// used to mint an auth key. +// +// 5. Workload identity federation ([Server.ClientID] plus +// [Server.IDToken] or [Server.Audience]). Available only if the +// program imports the feature: +// +// import _ "tailscale.com/feature/identityfederation" +// +// The feature is not linked by default to keep the AWS SDK and +// other cloud-provider dependencies out of programs that don't +// use workload identity federation. +// +// 6. An interactive login URL printed to [Server.UserLogf]. +// +// If the node is already enrolled (state found in [Server.Store]), the +// auth key is ignored unless TSNET_FORCE_LOGIN=1 is set. +// +// # Identifying callers +// +// Use the WhoIs method on the client returned by [Server.LocalClient] +// to identify who is making a request: +// +// lc, _ := srv.LocalClient() +// http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// who, err := lc.WhoIs(r.Context(), r.RemoteAddr) +// if err != nil { +// http.Error(w, err.Error(), 500) +// return +// } +// fmt.Fprintf(w, "Hello, %s!", who.UserProfile.LoginName) +// })) +// +// # Tailscale Funnel +// +// [Server.ListenFunnel] exposes your service on the public internet. +// [Tailscale Funnel] currently supports TCP on ports 443, 8443, and +// 10000. HTTPS must be enabled in the Tailscale admin console. +// +// ln, err := srv.ListenFunnel("tcp", ":443") +// // ln is a TLS listener; connections can come from anywhere on the +// // internet as well as from your tailnet. +// +// // To restrict to public traffic only: +// ln, err = srv.ListenFunnel("tcp", ":443", tsnet.FunnelOnly()) +// +// # Tailscale Services +// +// [Server.ListenService] advertises the node as a host for a named +// [Tailscale Service]. The node must use a tag-based identity. To +// advertise multiple ports, call ListenService once per port. +// +// srv.AdvertiseTags = []string{"tag:myservice"} +// +// ln, err := srv.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ +// HTTPS: true, +// Port: 443, +// }) +// log.Printf("Listening on https://%s", ln.FQDN) +// +// # Running multiple nodes in one process +// +// Each [Server] instance is an independent node. Give each a unique +// [Server.Dir] and [Server.Hostname]: +// +// for _, name := range []string{"frontend", "backend"} { +// srv := &tsnet.Server{ +// Hostname: name, +// Dir: filepath.Join(baseDir, name), +// AuthKey: os.Getenv("TS_AUTHKEY"), +// Ephemeral: true, +// } +// srv.Start() +// } +// +// [Tailscale identity]: https://tailscale.com/docs/concepts/tailscale-identity +// [Tailscale Funnel]: https://tailscale.com/docs/features/tailscale-funnel +// [Tailscale Service]: https://tailscale.com/docs/features/tailscale-services package tsnet import ( @@ -31,7 +163,6 @@ import ( "tailscale.com/control/controlclient" "tailscale.com/envknob" _ "tailscale.com/feature/c2n" - _ "tailscale.com/feature/condregister/identityfederation" _ "tailscale.com/feature/condregister/oauthkey" _ "tailscale.com/feature/condregister/portmapper" _ "tailscale.com/feature/condregister/useproxy" @@ -59,6 +190,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/nettype" + "tailscale.com/types/views" "tailscale.com/util/clientmetric" "tailscale.com/util/mak" "tailscale.com/util/set" @@ -151,6 +283,8 @@ type Server struct { // ControlURL optionally specifies the coordination server URL. // If empty, the Tailscale default is used. + // If empty, it defaults to the TS_CONTROL_URL environment variable. + // If that is also empty, the Tailscale default is used. ControlURL string // RunWebClient, if true, runs a client for managing this node over @@ -173,32 +307,33 @@ type Server struct { // This field must be set before calling Start. Tun tun.Device - initOnce sync.Once - initErr error - lb *ipnlocal.LocalBackend - sys *tsd.System - netstack *netstack.Impl - netMon *netmon.Monitor - rootPath string // the state directory - hostname string - shutdownCtx context.Context - shutdownCancel context.CancelFunc - proxyCred string // SOCKS5 proxy auth for loopbackListener - localAPICred string // basic auth password for loopbackListener - loopbackListener net.Listener // optional loopback for localapi and proxies - localAPIListener net.Listener // in-memory, used by localClient - localClient *local.Client // in-memory - localAPIServer *http.Server - resetServeConfigOnce sync.Once - logbuffer *filch.Filch - logtail *logtail.Logger - logid logid.PublicID + initOnce sync.Once + initErr error + lb *ipnlocal.LocalBackend + sys *tsd.System + netstack *netstack.Impl + netMon *netmon.Monitor + rootPath string // the state directory + hostname string + shutdownCtx context.Context + shutdownCancel context.CancelFunc + proxyCred string // SOCKS5 proxy auth for loopbackListener + localAPICred string // basic auth password for loopbackListener + loopbackListener net.Listener // optional loopback for localapi and proxies + localAPIListener net.Listener // in-memory, used by localClient + localClient *local.Client // in-memory + localAPIServer *http.Server + resetServeStateOnce sync.Once + logbuffer *filch.Filch + logtail *logtail.Logger + logid logid.PublicID mu sync.Mutex listeners map[listenKey]*listener nextEphemeralPort uint16 // next port to try in ephemeral range; 0 means use ephemeralPortFirst fallbackTCPHandlers set.HandleSet[FallbackTCPHandler] dialer *tsdial.Dialer + advertisedServices map[tailcfg.ServiceName]int closeOnce sync.Once } @@ -279,6 +414,19 @@ func (s *Server) LocalClient() (*local.Client, error) { return s.localClient, nil } +// TestHooks are hooks meant for internal-testing only; they're not stable +// or documented, intentionally. +var TestHooks testHooks + +type testHooks struct{} + +// LocalBackend returns the [ipnlocal.LocalBackend] backing s. It panics +// outside of tests. +func (testHooks) LocalBackend(s *Server) *ipnlocal.LocalBackend { + testenv.AssertInTest() + return s.lb +} + // Loopback starts a routing server on a loopback address. // // The server has multiple functions. @@ -413,15 +561,27 @@ func (s *Server) Up(ctx context.Context) (*ipnstate.Status, error) { return nil, errors.New("tsnet.Up: running, but no ip") } - // The first time Up is run, clear the persisted serve config. - // We do this to prevent messy interactions with stale config in - // the face of code changes. - var srvResetErr error - s.resetServeConfigOnce.Do(func() { - srvResetErr = lc.SetServeConfig(ctx, new(ipn.ServeConfig)) + // The first time Up is run, clear the persisted serve config + // and Service advertisements. We do this to prevent messy + // interactions with stale config in the face of code changes. + var srvCfgErr error + var svcAdErr error + s.resetServeStateOnce.Do(func() { + if err := lc.SetServeConfig(ctx, new(ipn.ServeConfig)); err != nil { + srvCfgErr = fmt.Errorf("clearing serve config: %w", err) + } + _, err := s.lb.EditPrefs(&ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: []string{}, + }, + }) + if err != nil { + svcAdErr = fmt.Errorf("clearing Service advertisements: %w", err) + } }) - if srvResetErr != nil { - return nil, fmt.Errorf("tsnet.Up: clearing serve config: %w", err) + if err := errors.Join(srvCfgErr, svcAdErr); err != nil { + return nil, fmt.Errorf("tsnet.Up: %w", err) } return status, nil @@ -466,9 +626,7 @@ func (s *Server) close() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { // Perform a best-effort final flush. if s.logtail != nil { s.logtail.Shutdown(ctx) @@ -476,14 +634,12 @@ func (s *Server) close() { if s.logbuffer != nil { s.logbuffer.Close() } - }() - wg.Add(1) - go func() { - defer wg.Done() + }) + wg.Go(func() { if s.localAPIServer != nil { s.localAPIServer.Shutdown(ctx) } - }() + }) if s.shutdownCancel != nil { s.shutdownCancel() @@ -523,7 +679,7 @@ func (s *Server) doInit() { // Server. // If the server is not running, it returns nil. func (s *Server) CertDomains() []string { - nm := s.lb.NetMap() + nm := s.lb.NetMapNoPeers() if nm == nil { return nil } @@ -534,7 +690,7 @@ func (s *Server) CertDomains() []string { // has not yet joined a tailnet or is otherwise unaware of its own IP addresses, // the returned ip4, ip6 will be !netip.IsValid(). func (s *Server) TailscaleIPs() (ip4, ip6 netip.Addr) { - nm := s.lb.NetMap() + nm := s.lb.NetMapNoPeers() if nm == nil { return } @@ -572,6 +728,13 @@ func (s *Server) getAuthKey() string { return os.Getenv("TS_AUTH_KEY") } +func (s *Server) getControlURL() string { + if v := s.ControlURL; v != "" { + return v + } + return os.Getenv("TS_CONTROL_URL") +} + func (s *Server) getClientSecret() string { if v := s.ClientSecret; v != "" { return v @@ -613,10 +776,15 @@ func (s *Server) start() (reterr error) { // directory and hostname when they're not supplied. But we can fall // back to "tsnet" as well. exe = "tsnet" - case "ios": + case "ios", "darwin": // When compiled as a framework (via TailscaleKit in libtailscale), - // os.Executable() returns an error, so fall back to "tsnet" there - // too. + // os.Executable() returns an error on iOS. The same failure occurs + // on macOS (darwin) when the framework is loaded in a process + // launched by a debugger or certain host environments (e.g. Xcode), + // where the OS does not expose a resolvable executable path to the + // embedded Go runtime. Fall back to "tsnet" in both cases — the + // value is only used as a default hostname/directory when neither + // Server.Hostname nor Server.Dir is set. exe = "tsnet" default: return err @@ -686,6 +854,7 @@ func (s *Server) start() (reterr error) { SetSubsystem: sys.Set, ControlKnobs: sys.ControlKnobs(), HealthTracker: sys.HealthTracker.Get(), + ExtraRootCAs: sys.ExtraRootCAs, Metrics: sys.UserMetricsRegistry(), }) if err != nil { @@ -773,7 +942,7 @@ func (s *Server) start() (reterr error) { prefs := ipn.NewPrefs() prefs.Hostname = s.hostname prefs.WantRunning = true - prefs.ControlURL = s.ControlURL + prefs.ControlURL = s.getControlURL() prefs.RunWebClient = s.RunWebClient prefs.AdvertiseTags = s.AdvertiseTags authKey, err := s.resolveAuthKey() @@ -862,7 +1031,7 @@ func (s *Server) resolveAuthKey() (string, error) { return "", fmt.Errorf("audience for workload identity federation found, but client ID is empty") } } - authKey, err = resolveViaWIF(s.shutdownCtx, s.ControlURL, clientID, idToken, audience, s.AdvertiseTags) + authKey, err = resolveViaWIF(s.shutdownCtx, s.getControlURL(), clientID, idToken, audience, s.AdvertiseTags) if err != nil { return "", err } @@ -936,6 +1105,22 @@ func (s *Server) logf(format string, a ...any) { // printAuthURLLoop loops once every few seconds while the server is still running and // is in NeedsLogin state, printing out the auth URL. func (s *Server) printAuthURLLoop() { + ctx, cancel := context.WithCancel(s.shutdownCtx) + defer cancel() + stateCh := make(chan struct{}, 1) + go s.lb.WatchNotifications(ctx, ipn.NotifyInitialState, nil, func(n *ipn.Notify) (keepGoing bool) { + if n.State == nil { + return true + } + + // No need to block, we only want to make sure the loop below is not + // blocking on time.After if there's a new state available. + select { + case stateCh <- struct{}{}: + default: + } + return true + }) for { if s.shutdownCtx.Err() != nil { return @@ -950,6 +1135,7 @@ func (s *Server) printAuthURLLoop() { } select { case <-time.After(5 * time.Second): + case <-stateCh: case <-s.shutdownCtx.Done(): return } @@ -1447,6 +1633,13 @@ type ServiceListener struct { // FQDN is the fully-qualifed domain name of this Service. FQDN string + + // Used by Close. + closeOnce sync.Once + closeErr error // written to during execution of closeOnce, read by Close() + s *Server // read and written to during execution of closeOnce + svcName tailcfg.ServiceName // read during execution of closeOnce + mode ServiceMode // read during execution of closeOnce } // Addr returns the listener's network address. This will be the Service's @@ -1454,16 +1647,142 @@ type ServiceListener struct { // // A hostname is not truly a network address, but Services listen on multiple // addresses (the IPv4 and IPv6 virtual IPs). -func (sl ServiceListener) Addr() net.Addr { +func (sl *ServiceListener) Addr() net.Addr { return sl.addr } +// cleanServeConfig cleans serve config changes made to support this listener. +// This should only be called by Close. +func (sl *ServiceListener) cleanServeConfig() error { + sc, etag, err := sl.s.lb.ServeConfigETag() + if err != nil { + return fmt.Errorf("fetching current config: %w", err) + } + if !sc.Valid() || !sc.Services().Contains(sl.svcName) { + return nil + } + srvConfig := sc.AsStruct() + svcConfig := srvConfig.Services[sl.svcName] + switch m := sl.mode.(type) { + case ServiceModeTCP: + delete(svcConfig.TCP, m.Port) + case ServiceModeHTTP: + hp := net.JoinHostPort(sl.FQDN, strconv.Itoa(int(m.Port))) + delete(svcConfig.Web, ipn.HostPort(hp)) + delete(svcConfig.TCP, m.Port) + default: + return fmt.Errorf("unexpected ServiceMode %T", sl.mode) + } + if err := sl.s.lb.SetServeConfig(srvConfig, etag); err != nil { + return fmt.Errorf("setting config: %w", err) + } + return nil +} + +// Close closes the listener and clears state related to hosting the Service. +// Behavior is undefined after the [Server] has been closed. +func (sl *ServiceListener) Close() error { + // We should only clean up state once. Otherwise we can stomp on state + // created by new listeners. + sl.closeOnce.Do(func() { + sl.s.mu.Lock() + defer sl.s.mu.Unlock() + + // Two pieces of state we need to clear: + // 1. The Service advertisement pref + // 2. Artifacts in the serve config + // Then we can close the listener. + + var adErr error + if err := sl.s.decrementServiceAdvertisementLocked(sl.svcName); err != nil { + adErr = fmt.Errorf("managing Service advertisements: %w", err) + } + + var srvCfgErr error + if err := sl.cleanServeConfig(); err != nil { + srvCfgErr = fmt.Errorf("cleaning config changes: %w", err) + } + + sl.closeErr = errors.Join(sl.Listener.Close(), adErr, srvCfgErr) + }) + return sl.closeErr +} + // ErrUntaggedServiceHost is returned by ListenService when run on a node // without any ACL tags. A node must use a tag-based identity to act as a // Service host. For more information, see: // https://tailscale.com/kb/1552/tailscale-services#prerequisites var ErrUntaggedServiceHost = errors.New("service hosts must be tagged nodes") +// advertiseService ensures the Service is advertised by this node. +func (s *Server) advertiseService(name tailcfg.ServiceName) error { + s.mu.Lock() + defer s.mu.Unlock() + + advertised := s.lb.Prefs().AdvertiseServices() + if !views.SliceContains(advertised, name.String()) { + newAdvertised := make([]string, 0, advertised.Len()+1) + newAdvertised = advertised.AppendTo(newAdvertised) + newAdvertised = append(newAdvertised, name.String()) + _, err := s.lb.EditPrefs(&ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: newAdvertised, + }, + }) + if err != nil { + return err + } + } + mak.Set(&s.advertisedServices, name, s.advertisedServices[name]+1) + return nil +} + +// decrementServiceAdvertisement decrements the count of listeners this node has +// advertising the Service. Advertisement of the Service will be withdrawn if +// the count hits zero. It is an error to call this function when the Service is +// not being advertised by this node. +func (s *Server) decrementServiceAdvertisementLocked(name tailcfg.ServiceName) error { + cleanAdvertisement := func() error { + delete(s.advertisedServices, name) + advertised := s.lb.Prefs().AdvertiseServices() + if !views.SliceContains(advertised, name.String()) { + return nil + } + newAdvertised := make([]string, 0, advertised.Len()-1) + for _, svc := range advertised.All() { + if svc == name.String() { + continue + } + newAdvertised = append(newAdvertised, svc) + } + _, err := s.lb.EditPrefs(&ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: newAdvertised, + }, + }) + return err + } + + if s.advertisedServices[name] <= 0 { + advertisements := s.advertisedServices[name] + // We somehow mismatched increments and decrements. Clear current + // advertisements and surface the mismatch as an error. + return errors.Join( + cleanAdvertisement(), + fmt.Errorf("service decrement requested with %d advertisements", advertisements), + ) + } + s.advertisedServices[name]-- + if s.advertisedServices[name] > 0 { + // If there are still listeners advertising the Service, then there's + // nothing more for us to do. + return nil + } + return cleanAdvertisement() +} + // ListenService creates a network listener for a Tailscale Service. This will // advertise this node as hosting the Service. Note that: // - Approval must still be granted by an admin or by ACL auto-approval rules. @@ -1476,13 +1795,22 @@ var ErrUntaggedServiceHost = errors.New("service hosts must be tagged nodes") // // This function will start the server if it is not already started. func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener, error) { - if err := tailcfg.ServiceName(name).Validate(); err != nil { + svcName := tailcfg.ServiceName(name) + if err := svcName.Validate(); err != nil { return nil, err } if mode == nil { return nil, errors.New("mode may not be nil") } - svcName := name + + // We collect cleanup tasks as we go and execute these on error. If we make + // it to the end we abandon these cleanup tasks by setting onError to nil. + var onError []func() + defer func() { + for _, f := range onError { + f() + } + }() // TODO(hwh33,tailscale/corp#35859): support TUN mode @@ -1497,31 +1825,25 @@ func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener, return nil, ErrUntaggedServiceHost } - advertisedServices := s.lb.Prefs().AdvertiseServices().AsSlice() - if !slices.Contains(advertisedServices, svcName) { - // TODO(hwh33,tailscale/corp#35860): clean these prefs up when (a) we - // exit early due to error or (b) when the returned listener is closed. - _, err = s.lb.EditPrefs(&ipn.MaskedPrefs{ - AdvertiseServicesSet: true, - Prefs: ipn.Prefs{ - AdvertiseServices: append(advertisedServices, svcName), - }, - }) - if err != nil { - return nil, fmt.Errorf("updating advertised Services: %w", err) - } + if err := s.advertiseService(svcName); err != nil { + return nil, fmt.Errorf("advertising Service: %w", err) } + onError = append(onError, func() { + s.mu.Lock() + defer s.mu.Unlock() + s.decrementServiceAdvertisementLocked(svcName) + }) - srvConfig := new(ipn.ServeConfig) - sc, srvConfigETag, err := s.lb.ServeConfigETag() + srvCfg := new(ipn.ServeConfig) + sc, srvCfgETag, err := s.lb.ServeConfigETag() if err != nil { return nil, fmt.Errorf("fetching current serve config: %w", err) } if sc.Valid() { - srvConfig = sc.AsStruct() + srvCfg = sc.AsStruct() } - fqdn := tailcfg.ServiceName(svcName).WithoutPrefix() + "." + st.CurrentTailnet.MagicDNSSuffix + fqdn := svcName.WithoutPrefix() + "." + st.CurrentTailnet.MagicDNSSuffix // svcAddr is used to implement Addr() on the returned listener. svcAddr := addr{ @@ -1537,6 +1859,13 @@ func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener, if m.port() == 0 { return nil, errors.New("must specify a port to advertise") } + if svcCfg, ok := srvCfg.Services[svcName]; ok { + if _, handlerExists := svcCfg.TCP[m.port()]; handlerExists { + // We know that a handler must have been started in this runtime + // because serve config is reset on the first [Server.Up]. + return nil, errors.New("a Service handler already exists for this port") + } + } svcAddr.addr += ":" + strconv.Itoa(int(m.port())) } @@ -1545,11 +1874,12 @@ func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener, if err != nil { return nil, fmt.Errorf("starting local listener: %w", err) } + onError = append(onError, func() { ln.Close() }) switch m := mode.(type) { case ServiceModeTCP: // Forward all connections from service-hostname:port to our socket. - srvConfig.SetTCPForwardingForService( + srvCfg.SetTCPForwardingForService( m.Port, ln.Addr().String(), m.TerminateTLS, tailcfg.ServiceName(svcName), m.PROXYProtocolVersion, st.CurrentTailnet.MagicDNSSuffix) case ServiceModeHTTP: @@ -1570,30 +1900,29 @@ func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener, } else { h.Proxy += path } - srvConfig.SetWebHandler(&h, svcName, m.Port, path, m.HTTPS, mds) + srvCfg.SetWebHandler(&h, svcName.String(), m.Port, path, m.HTTPS, mds) } // We always need a root handler. if !haveRootHandler { h := ipn.HTTPHandler{Proxy: ln.Addr().String()} - srvConfig.SetWebHandler(&h, svcName, m.Port, "/", m.HTTPS, mds) + srvCfg.SetWebHandler(&h, svcName.String(), m.Port, "/", m.HTTPS, mds) } default: - ln.Close() return nil, fmt.Errorf("unknown ServiceMode type %T", m) } - if err := s.lb.SetServeConfig(srvConfig, srvConfigETag); err != nil { - ln.Close() + if err := s.lb.SetServeConfig(srvCfg, srvCfgETag); err != nil { return nil, err } - // TODO(hwh33,tailscale/corp#35860): clean up state (advertising prefs, - // serve config changes) when the returned listener is closed. - + onError = nil return &ServiceListener{ Listener: ln, FQDN: fqdn, addr: svcAddr, + s: s, + svcName: svcName, + mode: mode, }, nil } @@ -1850,9 +2179,9 @@ func (s *Server) GetRootPath() string { // debugging, probably not useful for production. // // Packets will be written to the pcap until the process exits. The pcap needs a Lua dissector -// to be installed in WireShark in order to decode properly: wgengine/capture/ts-dissector.lua +// to be installed in Wireshark in order to decode properly: wgengine/capture/ts-dissector.lua // in this repository. -// https://tailscale.com/kb/1023/troubleshooting/#can-i-examine-network-traffic-inside-the-encrypted-tunnel +// https://tailscale.com/docs/reference/troubleshooting/network-configuration/inspect-unencrypted-packets func (s *Server) CapturePcap(ctx context.Context, pcapFile string) error { stream, err := s.localClient.StreamDebugCapture(ctx) if err != nil { diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 1cf4bf48fe5bd..4ee0ab10cc1a0 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -30,6 +30,7 @@ import ( "reflect" "runtime" "slices" + "strconv" "strings" "sync" "sync/atomic" @@ -58,6 +59,7 @@ import ( "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/netmap" "tailscale.com/types/views" "tailscale.com/util/mak" "tailscale.com/util/must" @@ -846,6 +848,8 @@ func TestFunnel(t *testing.T) { // after itself when closed. Specifically, changes made to the serve config // should be cleared. func TestFunnelClose(t *testing.T) { + tstest.Shard(t) + marshalServeConfig := func(t *testing.T, sc ipn.ServeConfigView) string { t.Helper() return string(must.Get(json.MarshalIndent(sc, "", "\t"))) @@ -874,7 +878,7 @@ func TestFunnelClose(t *testing.T) { // To obtain config the listener might want to clobber, we: // - run a listener // - grab the config - // - close the listener (clearing config) + // - close the listener (so we can run another on the same port) ln := must.Get(s.ListenFunnel("tcp", ":443")) before := s.lb.ServeConfig() ln.Close() @@ -932,33 +936,101 @@ func TestFunnelClose(t *testing.T) { // The listener should immediately return an error indicating closure. _, err := ln.Accept() - // Looking for a string in the error sucks, but it's supposed to stay - // consistent: - // https://github.com/golang/go/blob/108b333d510c1f60877ac917375d7931791acfe6/src/internal/poll/fd.go#L20-L24 - if err == nil || !strings.Contains(err.Error(), "use of closed network connection") { + if !errors.Is(err, net.ErrClosed) { t.Fatal("expected listener to be closed, got:", err) } }) } -func TestListenService(t *testing.T) { - // First test an error case which doesn't require all of the fancy setup. - t.Run("untagged_node_error", func(t *testing.T) { - ctx := t.Context() - - controlURL, _ := startControl(t) - serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host") +// setUpServiceState performs all necessary state setup for testing with a +// Tailscale Service. When this function returns, the host will be able to +// advertise a Service (via [Server.ListenService]) and the client will be able +// to dial the Service via the Service name. +// +// extraSetup, when non-nil, can be used to perform additional state setup and +// this state will be observable by client and host when this function returns. +func setUpServiceState(t *testing.T, name, ip string, host, client *Server, + control *testcontrol.Server, extraSetup func(*testing.T, *testcontrol.Server)) { - ln, err := serviceHost.ListenService("svc:foo", ServiceModeTCP{Port: 8080}) - if ln != nil { - ln.Close() - } - if !errors.Is(err, ErrUntaggedServiceHost) { - t.Fatalf("expected %v, got %v", ErrUntaggedServiceHost, err) + t.Helper() + serviceName := tailcfg.ServiceName(name) + must.Do(serviceName.Validate()) + + // The Service host must have the 'service-host' capability, which + // is a mapping from the Service name to the Service VIP. + cm := host.lb.NetMap().SelfNode.CapMap() + svcIPMap := make(tailcfg.ServiceIPMappings) + if cm.Contains(tailcfg.NodeAttrServiceHost) { + parsed := must.Get(tailcfg.UnmarshalNodeCapViewJSON[tailcfg.ServiceIPMappings](cm, tailcfg.NodeAttrServiceHost)) + if len(parsed) != 1 { + t.Fatalf("expected only one capability for %v, got %d", tailcfg.NodeAttrServiceHost, len(parsed)) + } + svcIPMap = parsed[0] + } + svcIPMap[serviceName] = []netip.Addr{netip.MustParseAddr(ip)} + svcIPMapJSON := must.Get(json.Marshal(svcIPMap)) + newCM := cm.AsMap() + mak.Set(&newCM, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(svcIPMapJSON)}) + control.SetNodeCapMap(host.lb.NodeKey(), newCM) + + // The Service host must be allowed to advertise the Service VIP. + subnetRoutes := []netip.Prefix{netip.MustParsePrefix(ip + `/32`)} + selfAddresses := host.lb.NetMap().SelfNode.Addresses() + for _, existingRoute := range host.lb.NetMap().SelfNode.AllowedIPs().All() { + if views.SliceContains(selfAddresses, existingRoute) { + continue } + subnetRoutes = append(subnetRoutes, existingRoute) + } + control.SetSubnetRoutes(host.lb.NodeKey(), subnetRoutes) + + // The Service host must be a tagged node (any tag will do). + serviceHostNode := control.Node(host.lb.NodeKey()) + serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag") + control.UpdateNode(serviceHostNode) + + // The service client must accept routes advertised by other nodes + // (RouteAll is equivalent to --accept-routes). + must.Get(client.localClient.EditPrefs(t.Context(), &ipn.MaskedPrefs{ + RouteAllSet: true, + Prefs: ipn.Prefs{ + RouteAll: true, + }, + })) + + // Do the test's extra setup before configuring DNS. This allows + // us to use the configured DNS records as sentinel values when + // waiting for all of this setup to be visible to test nodes. + if extraSetup != nil { + extraSetup(t, control) + } + + // Set up DNS for our Service. + control.AddDNSRecords(tailcfg.DNSRecord{ + Name: serviceName.WithoutPrefix() + "." + control.MagicDNSDomain, + Value: ip, }) - // Now on to the fancier tests. + // Wait until both nodes have up-to-date netmaps before + // proceeding with the test. + netmapUpToDate := func(nm *netmap.NetworkMap) bool { + return nm != nil && slices.ContainsFunc(nm.DNS.ExtraRecords, func(r tailcfg.DNSRecord) bool { + return r.Value == ip + }) + } + waitForLatestNetmap := func(t *testing.T, s *Server) { + t.Helper() + w := must.Get(s.localClient.WatchIPNBus(t.Context(), ipn.NotifyInitialNetMap)) + defer w.Close() + for n := must.Get(w.Next()); !netmapUpToDate(n.NetMap); n = must.Get(w.Next()) { + } + } + waitForLatestNetmap(t, client) + waitForLatestNetmap(t, host) +} + +func TestListenService(t *testing.T) { + tstest.Shard(t) type dialFn func(context.Context, string, string) (net.Conn, error) @@ -1224,86 +1296,287 @@ func TestListenService(t *testing.T) { // We run each test with and without a TUN device ([Server.Tun]). // Note that this TUN device is distinct from TUN mode for Services. doTest := func(t *testing.T, withTUNDevice bool) { - ctx := t.Context() - lt := setupTwoClientTest(t, withTUNDevice) serviceHost := lt.s2 serviceClient := lt.s1 - control := lt.control - const serviceName = tailcfg.ServiceName("svc:foo") + const serviceName = "svc:foo" const serviceVIP = "100.11.22.33" - // == Set up necessary state in our mock == + setUpServiceState(t, serviceName, serviceVIP, + serviceHost, serviceClient, lt.control, tt.extraSetup) - // The Service host must have the 'service-host' capability, which - // is a mapping from the Service name to the Service VIP. - var serviceHostCaps map[tailcfg.ServiceName]views.Slice[netip.Addr] - mak.Set(&serviceHostCaps, serviceName, views.SliceOf([]netip.Addr{netip.MustParseAddr(serviceVIP)})) - j := must.Get(json.Marshal(serviceHostCaps)) - cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap() - mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(j)}) - control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm) + listeners := make([]*ServiceListener, 0, len(tt.modes)) + for _, input := range tt.modes { + ln := must.Get(serviceHost.ListenService(serviceName, input)) + defer ln.Close() + listeners = append(listeners, ln) + } - // The Service host must be allowed to advertise the Service VIP. - control.SetSubnetRoutes(serviceHost.lb.NodeKey(), []netip.Prefix{ - netip.MustParsePrefix(serviceVIP + `/32`), - }) + tt.run(t, listeners, serviceClient) + } - // The Service host must be a tagged node (any tag will do). - serviceHostNode := control.Node(serviceHost.lb.NodeKey()) - serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag") - control.UpdateNode(serviceHostNode) - - // The service client must accept routes advertised by other nodes - // (RouteAll is equivalent to --accept-routes). - must.Get(serviceClient.localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ - RouteAllSet: true, - Prefs: ipn.Prefs{ - RouteAll: true, - }, - })) + t.Run("TUN", func(t *testing.T) { doTest(t, true) }) + t.Run("netstack", func(t *testing.T) { doTest(t, false) }) + }) + } - // Set up DNS for our Service. - control.AddDNSRecords(tailcfg.DNSRecord{ - Name: serviceName.WithoutPrefix() + "." + control.MagicDNSDomain, - Value: serviceVIP, - }) + // Error cases. + t.Run("untagged_node_error", func(t *testing.T) { + ctx := t.Context() + + controlURL, _ := startControl(t) + serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host") + + ln, err := serviceHost.ListenService("svc:foo", ServiceModeTCP{Port: 8080}) + if ln != nil { + ln.Close() + } + if !errors.Is(err, ErrUntaggedServiceHost) { + t.Fatalf("expected %v, got %v", ErrUntaggedServiceHost, err) + } + }) + t.Run("duplicate_listeners", func(t *testing.T) { + ctx := t.Context() + + const serviceName = "svc:foo" + + controlURL, control := startControl(t) + serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host") + serviceClient, _, _ := startServer(t, ctx, controlURL, "service-client") + + setUpServiceState(t, serviceName, "1.2.3.4", serviceHost, serviceClient, control, nil) + + ln := must.Get(serviceHost.ListenService(serviceName, ServiceModeTCP{Port: 8080})) + defer ln.Close() + + ln, err := serviceHost.ListenService(serviceName, ServiceModeTCP{Port: 8080}) + if ln != nil { + ln.Close() + } + if err == nil { + t.Fatal("expected error for redundant listener") + } + + // An HTTP listener on the same port should also collide + ln, err = serviceHost.ListenService(serviceName, ServiceModeHTTP{Port: 8080}) + if ln != nil { + ln.Close() + } + if err == nil { + t.Fatal("expected error for redundant listener") + } + }) + + t.Run("multiple_services", func(t *testing.T) { + const numberServices = 10 + const port = 80 - if tt.extraSetup != nil { - tt.extraSetup(t, control) + lt := setupTwoClientTest(t, false) + serviceHost := lt.s2 + serviceClient := lt.s1 + + names := make([]string, numberServices) + fqdns := make([]string, numberServices) + for i := range numberServices { + serviceName := "svc:foo" + strconv.Itoa(i+1) + serviceIP := `11.22.33.` + strconv.Itoa(i+1) + + setUpServiceState(t, serviceName, serviceIP, serviceHost, serviceClient, lt.control, nil) + ln := must.Get(serviceHost.ListenService(serviceName, ServiceModeTCP{Port: port})) + defer ln.Close() + names[i] = serviceName + fqdns[i] = ln.FQDN + + go func() { + // Accept a single connection, echo, then return. + conn, err := ln.Accept() + if err != nil { + t.Errorf("accept error from %v: %v", serviceName, err) + return } + defer conn.Close() + if _, err := io.Copy(conn, conn); err != nil { + t.Errorf("copy error from %v: %v", serviceName, err) + } + }() + } + for i := range numberServices { + msg := []byte("hello, " + fqdns[i]) + + conn := must.Get(serviceClient.Dial(t.Context(), "tcp", fqdns[i]+":"+strconv.Itoa(port))) + defer conn.Close() + must.Get(conn.Write(msg)) + buf := make([]byte, len(msg)) + n := must.Get(conn.Read(buf)) + if !bytes.Equal(buf[:n], msg) { + t.Fatalf("did not receive expected message:\n\tgot: %s\n\twant: %s\n", buf[:n], msg) + } + } + + // Each of the Services should be advertised by our Service host. + advertised := serviceHost.lb.Prefs().AdvertiseServices() + for _, name := range names { + if !views.SliceContains(advertised, name) { + t.Log("advertised Services:", advertised) + t.Fatalf("did not find %q in advertised Services", name) + } + } + }) +} - // Wait until both nodes have up-to-date netmaps before - // proceeding with the test. - netmapUpToDate := func(s *Server) bool { - nm := s.lb.NetMap() - return slices.ContainsFunc(nm.DNS.ExtraRecords, func(r tailcfg.DNSRecord) bool { - return r.Value == serviceVIP - }) +func TestListenServiceClose(t *testing.T) { + tstest.Shard(t) + const serviceName = "svc:foo" + + diffServeConfig := func(a, b ipn.ServeConfigView) string { + // We treat a mapping from svc:foo to nil or the zero value as if it + // didn't exist at all. This is consistent with how the local backend + // treats service configs when nil or zero. + tr := cmp.Transformer("DeleteEmptyServices", func(m map[tailcfg.ServiceName]*ipn.ServiceConfig) map[tailcfg.ServiceName]*ipn.ServiceConfig { + mCopy := map[tailcfg.ServiceName]*ipn.ServiceConfig{} + for k, v := range m { + if v == nil { + continue } - for !netmapUpToDate(serviceClient) { - time.Sleep(10 * time.Millisecond) + if rv := reflect.ValueOf(*v); rv.IsValid() && rv.IsZero() { + continue } - for !netmapUpToDate(serviceHost) { - time.Sleep(10 * time.Millisecond) + mCopy[k] = v + } + return mCopy + }) + + return cmp.Diff(a.AsStruct(), b.AsStruct(), tr) + } + + tests := []struct { + name string + run func(t *testing.T, serviceHost *Server) + }{ + { + name: "TCP", + run: func(t *testing.T, s *Server) { + before := s.lb.ServeConfig() + ln := must.Get(s.ListenService(serviceName, ServiceModeTCP{Port: 8080})) + ln.Close() + after := s.lb.ServeConfig() + if diff := diffServeConfig(after, before); diff != "" { + t.Fatalf("expected serve config to be unchanged after close (-got, +want):\n%s", diff) + } + }, + }, + { + name: "HTTP", + run: func(t *testing.T, s *Server) { + before := s.lb.ServeConfig() + ln := must.Get(s.ListenService(serviceName, ServiceModeHTTP{Port: 8080})) + ln.Close() + after := s.lb.ServeConfig() + if diff := diffServeConfig(after, before); diff != "" { + t.Fatalf("expected serve config to be unchanged after close (-got, +want):\n%s", diff) + } + }, + }, + { + // Closing one listener should not affect config for another listener. + name: "two_listeners", + run: func(t *testing.T, s *Server) { + // Start a listener on 443. + ln1 := must.Get(s.ListenService(serviceName, ServiceModeTCP{Port: 443})) + defer ln1.Close() + + // Save the serve config for this original listener. + before := s.lb.ServeConfig() + + // Now start and close a new listener on a different port. + ln2 := must.Get(s.ListenService(serviceName, ServiceModeTCP{Port: 8080})) + ln2.Close() + + // The serve config for the original listener should be intact. + after := s.lb.ServeConfig() + if diff := diffServeConfig(after, before); diff != "" { + t.Fatalf("expected existing config to remain intact (-got, +want):\n%s", diff) + } + }, + }, + { + // It should be possible to close a listener and free system + // resources even when the Server has been closed (or the listener + // should be automatically closed). + name: "after_server_close", + run: func(t *testing.T, s *Server) { + ln := must.Get(s.ListenService(serviceName, ServiceModeTCP{Port: 8080})) + + // Close the server, then close the listener. + must.Do(s.Close()) + // We don't care whether we get an error from the listener closing. + t.Log("close error:", ln.Close()) + + // The listener should immediately return an error indicating closure. + _, err := ln.Accept() + if !errors.Is(err, net.ErrClosed) { + t.Fatal("expected listener to be closed, got:", err) + } + }, + }, + { + // Regression test for https://github.com/tailscale/tailscale/issues/19169, + // in which concurrent ServiceListener.Close calls (by different + // listeners) would fail. + name: "concurrent_close", + run: func(t *testing.T, s *Server) { + const concurrentCloseCalls = 100 + + readyGroup := new(sync.WaitGroup) + closedGroup := new(sync.WaitGroup) + closeThemAll := make(chan (struct{})) + errC := make(chan error, concurrentCloseCalls) + for i := range concurrentCloseCalls { + readyGroup.Add(1) + closedGroup.Add(1) + ln := must.Get(s.ListenService(serviceName, ServiceModeTCP{ + Port: uint16(i + 1), + })) + go func() { + readyGroup.Done() + <-closeThemAll + errC <- ln.Close() + closedGroup.Done() + }() } - // == Done setting up mock state == + readyGroup.Wait() + close(closeThemAll) + closedGroup.Wait() + close(errC) - // Start the Service listeners. - listeners := make([]*ServiceListener, 0, len(tt.modes)) - for _, input := range tt.modes { - ln := must.Get(serviceHost.ListenService(serviceName.String(), input)) - defer ln.Close() - listeners = append(listeners, ln) + var errs []error + for err := range errC { + if err != nil { + errs = append(errs, err) + } + } + if len(errs) > 0 { + t.Fatalf("%d close errors; sample: %v", len(errs), errs[0]) + } + if diff := diffServeConfig(s.lb.ServeConfig(), (&ipn.ServeConfig{}).View()); diff != "" { + t.Fatalf("expected empty config (-got, +want):\n%s", diff) } + }, + }, + } - tt.run(t, listeners, serviceClient) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := t.Context() - t.Run("TUN", func(t *testing.T) { doTest(t, true) }) - t.Run("netstack", func(t *testing.T) { doTest(t, false) }) + controlURL, control := startControl(t) + serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host") + serviceClient, _, _ := startServer(t, ctx, controlURL, "service-client") + setUpServiceState(t, serviceName, "1.2.3.4", serviceHost, serviceClient, control, nil) + + tt.run(t, serviceHost) }) } } @@ -2598,7 +2871,7 @@ func buildDNSQuery(name string, srcIP netip.Addr) []byte { 0x00, 0x01, // QDCOUNT: 1 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ANCOUNT, NSCOUNT, ARCOUNT } - for _, label := range strings.Split(name, ".") { + for label := range strings.SplitSeq(name, ".") { dns = append(dns, byte(len(label))) dns = append(dns, label...) } @@ -2631,8 +2904,17 @@ func TestDeps(t *testing.T) { deptest.DepChecker{ GOOS: "linux", GOARCH: "amd64", + BadDeps: map[string]string{ + "golang.org/x/crypto/ssh": "tsnet should not depend on SSH", + "golang.org/x/crypto/ssh/internal/bcrypt_pbkdf": "tsnet should not depend on SSH", + "tailscale.com/ipn/store/awsstore": "tsnet callers wanting AWS state storage should import awsstore themselves", + "tailscale.com/ipn/store/kubestore": "tsnet callers wanting Kubernetes state storage should import kubestore themselves", + "tailscale.com/wif": "tsnet callers wanting workload identity federation should import tailscale.com/feature/identityfederation themselves", + }, OnDep: func(dep string) { - if strings.Contains(dep, "portlist") { + if strings.Contains(dep, "portlist") || + strings.Contains(dep, "github.com/aws/") || + strings.Contains(dep, "k8s.io/") { t.Errorf("unexpected dep: %q", dep) } }, @@ -2656,7 +2938,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains string }{ { - name: "successful resolution via OAuth client secret", + name: "success-oauth-client-secret", clientSecret: "tskey-client-secret-123", oauthAvailable: true, resolveViaOAuth: func(ctx context.Context, clientSecret string, tags []string) (string, error) { @@ -2669,7 +2951,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "", }, { - name: "failing resolution via OAuth client secret", + name: "fail-oauth-client-secret", clientSecret: "tskey-client-secret-123", oauthAvailable: true, resolveViaOAuth: func(ctx context.Context, clientSecret string, tags []string) (string, error) { @@ -2678,7 +2960,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "resolution failed", }, { - name: "successful resolution via federated ID token", + name: "success-federated-id-token", clientID: "client-id-123", idToken: "id-token-456", wifAvailable: true, @@ -2695,7 +2977,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "", }, { - name: "successful resolution via federated audience", + name: "success-federated-audience", clientID: "client-id-123", audience: "api.tailscale.com", wifAvailable: true, @@ -2712,7 +2994,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "", }, { - name: "failing resolution via federated ID token", + name: "fail-federated-id-token", clientID: "client-id-123", idToken: "id-token-456", wifAvailable: true, @@ -2722,7 +3004,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "resolution failed", }, { - name: "empty client ID with ID token", + name: "empty-client-id-with-token", clientID: "", idToken: "id-token-456", wifAvailable: true, @@ -2732,7 +3014,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "empty", }, { - name: "empty client ID with audience", + name: "empty-client-id-with-audience", clientID: "", audience: "api.tailscale.com", wifAvailable: true, @@ -2742,7 +3024,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "empty", }, { - name: "empty ID token", + name: "empty-id-token", clientID: "client-id-123", idToken: "", wifAvailable: true, @@ -2752,7 +3034,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "empty", }, { - name: "audience with ID token", + name: "audience-with-id-token", clientID: "client-id-123", idToken: "id-token-456", audience: "api.tailscale.com", @@ -2763,7 +3045,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "only one of ID token and audience", }, { - name: "workload identity resolution skipped if resolution via OAuth token succeeds", + name: "wif-skipped-oauth-succeeds", clientSecret: "tskey-client-secret-123", oauthAvailable: true, resolveViaOAuth: func(ctx context.Context, clientSecret string, tags []string) (string, error) { @@ -2780,7 +3062,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "", }, { - name: "workload identity resolution skipped if resolution via OAuth token fails", + name: "wif-skipped-oauth-fails", clientID: "tskey-client-id-123", idToken: "", oauthAvailable: true, @@ -2794,7 +3076,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "failed", }, { - name: "authkey set and no resolution available", + name: "authkey-set-no-resolution", authKey: "tskey-auth-123", oauthAvailable: false, wifAvailable: false, @@ -2802,14 +3084,14 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "", }, { - name: "no authkey set and no resolution available", + name: "no-authkey-no-resolution", oauthAvailable: false, wifAvailable: false, wantAuthKey: "", wantErrContains: "", }, { - name: "authkey is client secret and resolution via OAuth client secret succeeds", + name: "authkey-client-secret-oauth-succeeds", authKey: "tskey-client-secret-123", oauthAvailable: true, resolveViaOAuth: func(ctx context.Context, clientSecret string, tags []string) (string, error) { @@ -2822,7 +3104,7 @@ func TestResolveAuthKey(t *testing.T) { wantErrContains: "", }, { - name: "authkey is client secret but resolution via OAuth client secret fails", + name: "authkey-client-secret-oauth-fails", authKey: "tskey-client-secret-123", oauthAvailable: true, resolveViaOAuth: func(ctx context.Context, clientSecret string, tags []string) (string, error) { @@ -3005,12 +3287,12 @@ func TestListenUnspecifiedAddr(t *testing.T) { t.Run("Netstack", func(t *testing.T) { lt := setupTwoClientTest(t, false) - t.Run("0.0.0.0", func(t *testing.T) { testUnspec(t, lt, "0.0.0.0:8080", "8080") }) + t.Run("v4-unspec", func(t *testing.T) { testUnspec(t, lt, "0.0.0.0:8080", "8080") }) t.Run("::", func(t *testing.T) { testUnspec(t, lt, "[::]:8081", "8081") }) }) t.Run("TUN", func(t *testing.T) { lt := setupTwoClientTest(t, true) - t.Run("0.0.0.0", func(t *testing.T) { testUnspec(t, lt, "0.0.0.0:8080", "8080") }) + t.Run("v4-unspec", func(t *testing.T) { testUnspec(t, lt, "0.0.0.0:8080", "8080") }) t.Run("::", func(t *testing.T) { testUnspec(t, lt, "[::]:8081", "8081") }) }) } diff --git a/tstest/clock.go b/tstest/clock.go index 5742c6e5aeda1..1f88fb0a28611 100644 --- a/tstest/clock.go +++ b/tstest/clock.go @@ -20,6 +20,9 @@ type ClockOpts struct { // to Clock.Now. If you are passing a value here, set an explicit // timezone, otherwise the test may be non-deterministic when TZ environment // variable is set to different values. The default time is in UTC. + // + // If you do not pass an explicit Start time, the clock will start at the + // current UTC time. Start time.Time // Step is the amount of time the Clock will advance whenever Clock.Now is diff --git a/tstest/clock_test.go b/tstest/clock_test.go index cdfc2319ac115..5a05d57bbee71 100644 --- a/tstest/clock_test.go +++ b/tstest/clock_test.go @@ -22,7 +22,7 @@ func TestClockWithDefinedStartTime(t *testing.T) { wants []time.Time // The return values of sequential calls to Now(). }{ { - name: "increment ms", + name: "increment-ms", start: time.Unix(12345, 1000), step: 1000, wants: []time.Time{ @@ -33,7 +33,7 @@ func TestClockWithDefinedStartTime(t *testing.T) { }, }, { - name: "increment second", + name: "increment-second", start: time.Unix(12345, 1000), step: time.Second, wants: []time.Time{ @@ -44,7 +44,7 @@ func TestClockWithDefinedStartTime(t *testing.T) { }, }, { - name: "no increment", + name: "no-increment", start: time.Unix(12345, 1000), wants: []time.Time{ time.Unix(12345, 1000), @@ -91,7 +91,7 @@ func TestClockWithDefaultStartTime(t *testing.T) { wants []time.Duration // The return values of sequential calls to Now() after added to Start() }{ { - name: "increment ms", + name: "increment-ms", step: 1000, wants: []time.Duration{ 0, @@ -101,7 +101,7 @@ func TestClockWithDefaultStartTime(t *testing.T) { }, }, { - name: "increment second", + name: "increment-second", step: time.Second, wants: []time.Duration{ 0 * time.Second, @@ -111,7 +111,7 @@ func TestClockWithDefaultStartTime(t *testing.T) { }, }, { - name: "no increment", + name: "no-increment", wants: []time.Duration{0, 0, 0, 0}, }, } @@ -177,7 +177,7 @@ func TestClockSetStep(t *testing.T) { wants []time.Time // The return values of sequential calls to Now(). }{ { - name: "increment ms then s", + name: "increment-ms-then-s", start: time.Unix(12345, 1000), step: 1000, stepChanges: []stepInfo{ @@ -198,7 +198,7 @@ func TestClockSetStep(t *testing.T) { }, }, { - name: "multiple changes over time", + name: "multiple-changes-over-time", start: time.Unix(12345, 1000), step: 1, stepChanges: []stepInfo{ @@ -227,7 +227,7 @@ func TestClockSetStep(t *testing.T) { }, }, { - name: "multiple changes at once", + name: "multiple-changes-at-once", start: time.Unix(12345, 1000), step: 1, stepChanges: []stepInfo{ @@ -252,7 +252,7 @@ func TestClockSetStep(t *testing.T) { }, }, { - name: "changes at start", + name: "changes-at-start", start: time.Unix(12345, 1000), step: 0, stepChanges: []stepInfo{ @@ -325,7 +325,7 @@ func TestClockAdvance(t *testing.T) { wants []time.Time // The return values of sequential calls to Now(). }{ { - name: "increment ms then advance 1s", + name: "increment-ms-then-advance-1s", start: time.Unix(12345, 1000), step: 1000, advances: []advanceInfo{ @@ -346,7 +346,7 @@ func TestClockAdvance(t *testing.T) { }, }, { - name: "multiple advances over time", + name: "multiple-advances-over-time", start: time.Unix(12345, 1000), step: 1, advances: []advanceInfo{ @@ -375,7 +375,7 @@ func TestClockAdvance(t *testing.T) { }, }, { - name: "multiple advances at once", + name: "multiple-advances-at-once", start: time.Unix(12345, 1000), step: 1, advances: []advanceInfo{ @@ -400,7 +400,7 @@ func TestClockAdvance(t *testing.T) { }, }, { - name: "changes at start", + name: "changes-at-start", start: time.Unix(12345, 1000), step: 5, advances: []advanceInfo{ @@ -489,7 +489,7 @@ func TestSingleTicker(t *testing.T) { steps []testStep }{ { - name: "no tick advance", + name: "no-tick-advance", start: time.Unix(12345, 0), period: time.Second, steps: []testStep{ @@ -500,7 +500,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "no tick step", + name: "no-tick-step", start: time.Unix(12345, 0), step: time.Second - 1, period: time.Second, @@ -514,7 +514,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "single tick advance exact", + name: "single-tick-advance-exact", start: time.Unix(12345, 0), period: time.Second, steps: []testStep{ @@ -526,7 +526,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "single tick advance extra", + name: "single-tick-advance-extra", start: time.Unix(12345, 0), period: time.Second, steps: []testStep{ @@ -538,7 +538,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "single tick step exact", + name: "single-tick-step-exact", start: time.Unix(12345, 0), step: time.Second, period: time.Second, @@ -553,7 +553,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "single tick step extra", + name: "single-tick-step-extra", start: time.Unix(12345, 0), step: time.Second + 1, period: time.Second, @@ -568,7 +568,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "single tick per advance", + name: "single-tick-per-advance", start: time.Unix(12345, 0), period: 3 * time.Second, steps: []testStep{ @@ -597,7 +597,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "single tick per step", + name: "single-tick-per-step", start: time.Unix(12345, 0), step: 2 * time.Second, period: 3 * time.Second, @@ -626,7 +626,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "multiple tick per advance", + name: "multiple-tick-per-advance", start: time.Unix(12345, 0), period: time.Second, channelSize: 3, @@ -655,7 +655,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "multiple tick per step", + name: "multiple-tick-per-step", start: time.Unix(12345, 0), step: 3 * time.Second, period: 2 * time.Second, @@ -723,7 +723,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "reset while running", + name: "reset-while-running", start: time.Unix(12345, 0), period: 2 * time.Second, steps: []testStep{ @@ -763,7 +763,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "reset while stopped", + name: "reset-while-stopped", start: time.Unix(12345, 0), step: time.Second, period: 2 * time.Second, @@ -803,7 +803,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "reset absolute", + name: "reset-absolute", start: time.Unix(12345, 0), step: time.Second, period: 2 * time.Second, @@ -841,7 +841,7 @@ func TestSingleTicker(t *testing.T) { }, }, { - name: "follow real time", + name: "follow-real-time", realTimeOpts: new(ClockOpts), start: time.Unix(12345, 0), period: 2 * time.Second, @@ -965,7 +965,7 @@ func TestSingleTimer(t *testing.T) { steps []testStep }{ { - name: "no tick advance", + name: "no-tick-advance", start: time.Unix(12345, 0), delay: time.Second, steps: []testStep{ @@ -976,7 +976,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "no tick step", + name: "no-tick-step", start: time.Unix(12345, 0), step: time.Second - 1, delay: time.Second, @@ -990,7 +990,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "single tick advance exact", + name: "single-tick-advance-exact", start: time.Unix(12345, 0), delay: time.Second, steps: []testStep{ @@ -1006,7 +1006,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "single tick advance extra", + name: "single-tick-advance-extra", start: time.Unix(12345, 0), delay: time.Second, steps: []testStep{ @@ -1022,7 +1022,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "single tick step exact", + name: "single-tick-step-exact", start: time.Unix(12345, 0), step: time.Second, delay: time.Second, @@ -1040,7 +1040,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "single tick step extra", + name: "single-tick-step-extra", start: time.Unix(12345, 0), step: time.Second + 1, delay: time.Second, @@ -1058,7 +1058,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "reset for single tick per advance", + name: "reset-for-single-tick-per-advance", start: time.Unix(12345, 0), delay: 3 * time.Second, steps: []testStep{ @@ -1093,7 +1093,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "reset for single tick per step", + name: "reset-for-single-tick-per-step", start: time.Unix(12345, 0), step: 2 * time.Second, delay: 3 * time.Second, @@ -1124,7 +1124,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "reset while active", + name: "reset-while-active", start: time.Unix(12345, 0), step: 2 * time.Second, delay: 3 * time.Second, @@ -1155,7 +1155,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "stop after fire", + name: "stop-after-fire", start: time.Unix(12345, 0), step: 2 * time.Second, delay: time.Second, @@ -1181,7 +1181,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "stop before fire", + name: "stop-before-fire", start: time.Unix(12345, 0), step: 2 * time.Second, delay: time.Second, @@ -1207,7 +1207,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "stop after reset", + name: "stop-after-reset", start: time.Unix(12345, 0), step: 2 * time.Second, delay: time.Second, @@ -1235,7 +1235,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "reset while running", + name: "reset-while-running", start: time.Unix(12345, 0), delay: 2 * time.Second, steps: []testStep{ @@ -1275,7 +1275,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "reset while stopped", + name: "reset-while-stopped", start: time.Unix(12345, 0), step: time.Second, delay: 2 * time.Second, @@ -1310,7 +1310,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "reset absolute", + name: "reset-absolute", start: time.Unix(12345, 0), step: time.Second, delay: 2 * time.Second, @@ -1344,7 +1344,7 @@ func TestSingleTimer(t *testing.T) { }, }, { - name: "follow real time", + name: "follow-real-time", realTimeOpts: new(ClockOpts), start: time.Unix(12345, 0), delay: 2 * time.Second, @@ -1705,7 +1705,7 @@ func TestClockFollowRealTime(t *testing.T) { wants []time.Time // The return values of sequential calls to Now(). }{ { - name: "increment ms then advance 1s", + name: "increment-ms-then-advance-1s", start: time.Unix(12345, 1000), wantStart: time.Unix(12345, 1000), advances: []advanceInfo{ @@ -1750,7 +1750,7 @@ func TestClockFollowRealTime(t *testing.T) { }, }, { - name: "multiple advances over time", + name: "multiple-advances-over-time", start: time.Unix(12345, 1000), wantStart: time.Unix(12345, 1000), advances: []advanceInfo{ @@ -1795,7 +1795,7 @@ func TestClockFollowRealTime(t *testing.T) { }, }, { - name: "multiple advances at once", + name: "multiple-advances-at-once", start: time.Unix(12345, 1000), wantStart: time.Unix(12345, 1000), advances: []advanceInfo{ @@ -1828,7 +1828,7 @@ func TestClockFollowRealTime(t *testing.T) { }, }, { - name: "changes at start", + name: "changes-at-start", start: time.Unix(12345, 1000), wantStart: time.Unix(12345, 1000), advances: []advanceInfo{ @@ -1861,7 +1861,7 @@ func TestClockFollowRealTime(t *testing.T) { }, }, { - name: "start from current time", + name: "start-from-current-time", realTimeClockOpts: ClockOpts{ Start: time.Unix(12345, 0), }, @@ -1966,7 +1966,7 @@ func TestAfterFunc(t *testing.T) { steps []testStep }{ { - name: "no tick advance", + name: "no-tick-advance", start: time.Unix(12345, 0), delay: time.Second, steps: []testStep{ @@ -1977,7 +1977,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "no tick step", + name: "no-tick-step", start: time.Unix(12345, 0), step: time.Second - 1, delay: time.Second, @@ -1991,7 +1991,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "single tick advance exact", + name: "single-tick-advance-exact", start: time.Unix(12345, 0), delay: time.Second, steps: []testStep{ @@ -2007,7 +2007,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "single tick advance extra", + name: "single-tick-advance-extra", start: time.Unix(12345, 0), delay: time.Second, steps: []testStep{ @@ -2023,7 +2023,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "single tick step exact", + name: "single-tick-step-exact", start: time.Unix(12345, 0), step: time.Second, delay: time.Second, @@ -2041,7 +2041,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "single tick step extra", + name: "single-tick-step-extra", start: time.Unix(12345, 0), step: time.Second + 1, delay: time.Second, @@ -2059,7 +2059,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "reset for single tick per advance", + name: "reset-for-single-tick-per-advance", start: time.Unix(12345, 0), delay: 3 * time.Second, steps: []testStep{ @@ -2094,7 +2094,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "reset for single tick per step", + name: "reset-for-single-tick-per-step", start: time.Unix(12345, 0), step: 2 * time.Second, delay: 3 * time.Second, @@ -2125,7 +2125,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "reset while active", + name: "reset-while-active", start: time.Unix(12345, 0), step: 2 * time.Second, delay: 3 * time.Second, @@ -2156,7 +2156,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "stop after fire", + name: "stop-after-fire", start: time.Unix(12345, 0), step: 2 * time.Second, delay: time.Second, @@ -2182,7 +2182,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "stop before fire", + name: "stop-before-fire", start: time.Unix(12345, 0), step: 2 * time.Second, delay: time.Second, @@ -2208,7 +2208,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "stop after reset", + name: "stop-after-reset", start: time.Unix(12345, 0), step: 2 * time.Second, delay: time.Second, @@ -2236,7 +2236,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "reset while running", + name: "reset-while-running", start: time.Unix(12345, 0), delay: 2 * time.Second, steps: []testStep{ @@ -2270,7 +2270,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "reset while stopped", + name: "reset-while-stopped", start: time.Unix(12345, 0), step: time.Second, delay: 2 * time.Second, @@ -2303,7 +2303,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "reset absolute", + name: "reset-absolute", start: time.Unix(12345, 0), step: time.Second, delay: 2 * time.Second, @@ -2333,7 +2333,7 @@ func TestAfterFunc(t *testing.T) { }, }, { - name: "follow real time", + name: "follow-real-time", realTimeOpts: new(ClockOpts), start: time.Unix(12345, 0), delay: 2 * time.Second, diff --git a/tstest/deptest/deptest.go b/tstest/deptest/deptest.go index 3117af2fffa01..59672761ef06b 100644 --- a/tstest/deptest/deptest.go +++ b/tstest/deptest/deptest.go @@ -124,7 +124,7 @@ func ImportAliasCheck(t testing.TB, relDir string) { } badRx := regexp.MustCompile(`^([^:]+:\d+):\s+"golang\.org/x/exp/(slices|maps)"`) if s := strings.TrimSpace(string(matches)); s != "" { - for _, line := range strings.Split(s, "\n") { + for line := range strings.SplitSeq(s, "\n") { if m := badRx.FindStringSubmatch(line); m != nil { t.Errorf("%s: the x/exp/%s package should be imported as x%s", m[1], m[2], m[2]) } diff --git a/tstest/integration/integration.go b/tstest/integration/integration.go index a98df81808097..861ec808d0206 100644 --- a/tstest/integration/integration.go +++ b/tstest/integration/integration.go @@ -73,7 +73,11 @@ type Binaries struct { // BinaryInfo describes a tailscale or tailscaled binary. type BinaryInfo struct { - Path string // abs path to tailscale or tailscaled binary + // Path is the absolute path to the tailscale or tailscaled binary. + // This path may become invalid after the owning test's TempDir is + // cleaned up; use FD (or Contents on Windows) to access the binary + // contents. + Path string Size int64 // FD and FDmu are set on Unix to efficiently copy the binary to a new @@ -88,16 +92,24 @@ type BinaryInfo struct { Contents []byte } +// CopyTo copies or hardlinks the binary into dir, returning a new BinaryInfo +// with an updated Path. The source bytes come from FD (or Contents on Windows), +// not from b.Path, which may have been deleted when its owning test's TempDir +// was cleaned up. func (b BinaryInfo) CopyTo(dir string) (BinaryInfo, error) { ret := b ret.Path = filepath.Join(dir, path.Base(b.Path)) switch runtime.GOOS { case "linux": - // TODO(bradfitz): be fancy and use linkat with AT_EMPTY_PATH to avoid - // copying? I couldn't get it to work, though. - // For now, just do the same thing as every other Unix and copy - // the binary. + // Try to hardlink from the open FD via /proc/self/fd, avoiding a + // full copy of the binary. We can't use os.Link(b.Path, ret.Path) + // because b.Path is in the first test's TempDir, which may be + // cleaned up before later tests call CopyTo. The open FD keeps the + // inode alive after the path is deleted. + if err := tryLinkat(b.FD, ret.Path); err == nil { + return ret, nil + } fallthrough case "darwin", "freebsd", "openbsd", "netbsd": f, err := os.OpenFile(ret.Path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o755) @@ -1044,6 +1056,9 @@ func (n *TestNode) Tailscale(arg ...string) *exec.Cmd { cmd.Env = append(os.Environ(), "TS_DEBUG_UP_FLAG_GOOS="+n.upFlagGOOS, "TS_LOGS_DIR="+n.env.t.TempDir(), + "SSH_CLIENT=", // Clear SSH_CLIENT to prevent isSSHOverTailscale() false positives in tests + "SSH_CONNECTION=", // just in case + "SSH_AUTH_SOCK=", // just in case ) if *verboseTailscale { cmd.Stdout = os.Stdout diff --git a/tstest/integration/integration_linkat_linux.go b/tstest/integration/integration_linkat_linux.go new file mode 100644 index 0000000000000..68e9075d94bde --- /dev/null +++ b/tstest/integration/integration_linkat_linux.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package integration + +import ( + "fmt" + "os" + + "golang.org/x/sys/unix" +) + +// tryLinkat attempts to hardlink the file referenced by fd to newpath, +// avoiding a full copy of the binary. It uses /proc/self/fd/ with +// AT_SYMLINK_FOLLOW, which works without elevated privileges (unlike +// AT_EMPTY_PATH which requires CAP_DAC_READ_SEARCH). +func tryLinkat(fd *os.File, newpath string) error { + procPath := fmt.Sprintf("/proc/self/fd/%d", fd.Fd()) + err := unix.Linkat(unix.AT_FDCWD, procPath, unix.AT_FDCWD, newpath, unix.AT_SYMLINK_FOLLOW) + if err != nil { + return fmt.Errorf("linkat via /proc/self/fd: %w", err) + } + return nil +} diff --git a/tstest/integration/integration_linkat_linux_test.go b/tstest/integration/integration_linkat_linux_test.go new file mode 100644 index 0000000000000..fc0a2873f68bd --- /dev/null +++ b/tstest/integration/integration_linkat_linux_test.go @@ -0,0 +1,48 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package integration + +import ( + "os" + "path/filepath" + "testing" + + "golang.org/x/sys/unix" +) + +func TestTryLinkat(t *testing.T) { + src := filepath.Join(t.TempDir(), "src") + if err := os.WriteFile(src, []byte("hello world"), 0o755); err != nil { + t.Fatal(err) + } + fd, err := os.Open(src) + if err != nil { + t.Fatal(err) + } + defer fd.Close() + + dst := filepath.Join(t.TempDir(), "dst") + if err := tryLinkat(fd, dst); err != nil { + t.Fatal(err) + } + + got, err := os.ReadFile(dst) + if err != nil { + t.Fatal(err) + } + if string(got) != "hello world" { + t.Fatalf("got %q, want %q", got, "hello world") + } + + var stSrc, stDst unix.Stat_t + if err := unix.Stat(src, &stSrc); err != nil { + t.Fatal(err) + } + if err := unix.Stat(dst, &stDst); err != nil { + t.Fatal(err) + } + if stSrc.Ino != stDst.Ino { + t.Fatalf("inodes differ: src=%d, dst=%d", stSrc.Ino, stDst.Ino) + } +} diff --git a/tstest/integration/integration_linkat_other.go b/tstest/integration/integration_linkat_other.go new file mode 100644 index 0000000000000..7e22ca0daa002 --- /dev/null +++ b/tstest/integration/integration_linkat_other.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package integration + +import ( + "errors" + "os" +) + +func tryLinkat(_ *os.File, _ string) error { + return errors.New("linkat with AT_EMPTY_PATH not supported on this OS") +} diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 779cba6290cfe..3064d6a26f96d 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -50,7 +50,6 @@ import ( "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/types/opt" - "tailscale.com/types/ptr" "tailscale.com/util/must" "tailscale.com/util/set" ) @@ -74,9 +73,7 @@ func TestMain(m *testing.M) { // https://github.com/tailscale/tailscale/issues/7894 func TestTUNMode(t *testing.T) { tstest.Shard(t) - if os.Getuid() != 0 { - t.Skip("skipping when not root") - } + tstest.RequireRoot(t) tstest.Parallel(t) env := NewTestEnv(t) env.tunMode = true @@ -201,23 +198,34 @@ func TestExpectedFeaturesLinked(t *testing.T) { } func TestCollectPanic(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15865") tstest.Shard(t) tstest.Parallel(t) env := NewTestEnv(t) n := NewTestNode(t, env) - cmd := exec.Command(env.daemon, "--cleanup") + // Wait for the binary to be executable, working around a + // mysterious ETXTBSY on GitHub Actions. + // See https://github.com/tailscale/tailscale/issues/15868. + if err := n.awaitTailscaledRunnable(); err != nil { + t.Fatal(err) + } + + logsDir := t.TempDir() + cmd := exec.Command(env.daemon, "--cleanup", "--statedir="+n.dir) cmd.Env = append(os.Environ(), "TS_PLEASE_PANIC=1", "TS_LOG_TARGET="+n.env.LogCatcherServer.URL, + "TS_LOGS_DIR="+logsDir, ) got, _ := cmd.CombinedOutput() // we expect it to fail, ignore err t.Logf("initial run: %s", got) // Now we run it again, and on start, it will upload the logs to logcatcher. - cmd = exec.Command(env.daemon, "--cleanup") - cmd.Env = append(os.Environ(), "TS_LOG_TARGET="+n.env.LogCatcherServer.URL) + cmd = exec.Command(env.daemon, "--cleanup", "--statedir="+n.dir) + cmd.Env = append(os.Environ(), + "TS_LOG_TARGET="+n.env.LogCatcherServer.URL, + "TS_LOGS_DIR="+logsDir, + ) if out, err := cmd.CombinedOutput(); err != nil { t.Fatalf("cleanup failed: %v: %q", err, out) } @@ -459,83 +467,70 @@ func TestOneNodeUpAuth(t *testing.T) { }, } { tstest.Shard(t) + t.Run(tt.name, func(t *testing.T) { + tstest.Parallel(t) + + env := NewTestEnv(t, ConfigureControl( + func(control *testcontrol.Server) { + if tt.authKey != "" { + control.RequireAuthKey = tt.authKey + } else { + control.RequireAuth = true + } - for _, useSeamlessKeyRenewal := range []bool{true, false} { - name := tt.name - if useSeamlessKeyRenewal { - name += "-with-seamless" - } - t.Run(name, func(t *testing.T) { - tstest.Parallel(t) - - env := NewTestEnv(t, ConfigureControl( - func(control *testcontrol.Server) { - if tt.authKey != "" { - control.RequireAuthKey = tt.authKey - } else { - control.RequireAuth = true - } - - if tt.requireDeviceApproval { - control.RequireMachineAuth = true - } - - control.AllNodesSameUser = true - - if useSeamlessKeyRenewal { - control.DefaultNodeCapabilities = &tailcfg.NodeCapMap{ - tailcfg.NodeAttrSeamlessKeyRenewal: []tailcfg.RawMessage{}, - } - } - }, - )) + if tt.requireDeviceApproval { + control.RequireMachineAuth = true + } - n1 := NewTestNode(t, env) - d1 := n1.StartDaemon() - defer d1.MustCleanShutdown(t) + control.AllNodesSameUser = true + }, + )) - for i, step := range tt.steps { - t.Logf("Running step %d", i) - cmdArgs := append(step.args, "--login-server="+env.ControlURL()) + n1 := NewTestNode(t, env) + d1 := n1.StartDaemon() + defer d1.MustCleanShutdown(t) - t.Logf("Running command: %s", strings.Join(cmdArgs, " ")) + for i, step := range tt.steps { + t.Logf("Running step %d", i) + cmdArgs := append(step.args, "--login-server="+env.ControlURL()) - var authURLCount atomic.Int32 - var deviceApprovalURLCount atomic.Int32 + t.Logf("Running command: %s", strings.Join(cmdArgs, " ")) - handler := &authURLParserWriter{t: t, - authURLFn: completeLogin(t, env.Control, &authURLCount), - deviceApprovalURLFn: completeDeviceApproval(t, n1, &deviceApprovalURLCount), - } + var authURLCount atomic.Int32 + var deviceApprovalURLCount atomic.Int32 - cmd := n1.Tailscale(cmdArgs...) - cmd.Stdout = handler - cmd.Stdout = handler - cmd.Stderr = cmd.Stdout - if err := cmd.Run(); err != nil { - t.Fatalf("up: %v", err) - } + handler := &authURLParserWriter{t: t, + authURLFn: completeLogin(t, env.Control, &authURLCount), + deviceApprovalURLFn: completeDeviceApproval(t, n1, &deviceApprovalURLCount), + } - n1.AwaitRunning() + cmd := n1.Tailscale(cmdArgs...) + cmd.Stdout = handler + cmd.Stdout = handler + cmd.Stderr = cmd.Stdout + if err := cmd.Run(); err != nil { + t.Fatalf("up: %v", err) + } - var wantAuthURLCount int32 - if step.wantAuthURL { - wantAuthURLCount = 1 - } - if n := authURLCount.Load(); n != wantAuthURLCount { - t.Errorf("Auth URLs completed = %d; want %d", n, wantAuthURLCount) - } + n1.AwaitRunning() - var wantDeviceApprovalURLCount int32 - if step.wantDeviceApprovalURL { - wantDeviceApprovalURLCount = 1 - } - if n := deviceApprovalURLCount.Load(); n != wantDeviceApprovalURLCount { - t.Errorf("Device approval URLs completed = %d; want %d", n, wantDeviceApprovalURLCount) - } + var wantAuthURLCount int32 + if step.wantAuthURL { + wantAuthURLCount = 1 } - }) - } + if n := authURLCount.Load(); n != wantAuthURLCount { + t.Errorf("Auth URLs completed = %d; want %d", n, wantAuthURLCount) + } + + var wantDeviceApprovalURLCount int32 + if step.wantDeviceApprovalURL { + wantDeviceApprovalURLCount = 1 + } + if n := deviceApprovalURLCount.Load(); n != wantDeviceApprovalURLCount { + t.Errorf("Device approval URLs completed = %d; want %d", n, wantDeviceApprovalURLCount) + } + } + }) } } @@ -730,8 +725,8 @@ func TestConfigFileAuthKey(t *testing.T) { must.Do(os.WriteFile(authKeyFile, fmt.Appendf(nil, "%s\n", authKey), 0666)) must.Do(os.WriteFile(n1.configFile, must.Get(json.Marshal(ipn.ConfigVAlpha{ Version: "alpha0", - AuthKey: ptr.To("file:" + authKeyFile), - ServerURL: ptr.To(n1.env.ControlServer.URL), + AuthKey: new("file:" + authKeyFile), + ServerURL: new(n1.env.ControlServer.URL), })), 0644)) d1 := n1.StartDaemon() @@ -1555,9 +1550,7 @@ func testAutoUpdateDefaults(t *testing.T, useCap bool) { // https://github.com/tailscale/corp/issues/22511 func TestDNSOverTCPIntervalResolver(t *testing.T) { tstest.Shard(t) - if os.Getuid() != 0 { - t.Skip("skipping when not root") - } + tstest.RequireRoot(t) env := NewTestEnv(t) env.tunMode = true n1 := NewTestNode(t, env) @@ -1627,9 +1620,7 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) { // directions. func TestNetstackTCPLoopback(t *testing.T) { tstest.Shard(t) - if os.Getuid() != 0 { - t.Skip("skipping when not root") - } + tstest.RequireRoot(t) env := NewTestEnv(t) env.tunMode = true @@ -1674,7 +1665,7 @@ func TestNetstackTCPLoopback(t *testing.T) { defer lis.Close() writeFn := func(conn net.Conn) error { - for i := 0; i < writeBufIterations; i++ { + for range writeBufIterations { toWrite := make([]byte, writeBufSize) var wrote int for { @@ -1769,9 +1760,7 @@ func TestNetstackTCPLoopback(t *testing.T) { // directions. func TestNetstackUDPLoopback(t *testing.T) { tstest.Shard(t) - if os.Getuid() != 0 { - t.Skip("skipping when not root") - } + tstest.RequireRoot(t) env := NewTestEnv(t) env.tunMode = true @@ -2232,7 +2221,7 @@ func TestC2NDebugNetmap(t *testing.T) { // Send a delta update to n1, marking node 0 as online. env.Control.AddRawMapResponse(nodes[1].Key, &tailcfg.MapResponse{ PeersChangedPatch: []*tailcfg.PeerChange{{ - NodeID: nodes[0].ID, Online: ptr.To(true), + NodeID: nodes[0].ID, Online: new(true), }}, }) @@ -2364,6 +2353,38 @@ func TestTailnetLock(t *testing.T) { t.Fatalf("ping node3 -> signing1: expected success, got err: %v", err) } }) + + // If you run `tailscale lock (add|remove|revoke-keys)` but don't pass any keys, + // we print a helpful error message. + // + // Regression test for tailscale/tailscale#19130 + t.Run("no-keys-is-error", func(t *testing.T) { + for _, verb := range []string{"add", "remove", "revoke-keys"} { + t.Run(verb, func(t *testing.T) { + tstest.Shard(t) + t.Parallel() + + env := NewTestEnv(t) + n1 := NewTestNode(t, env) + d1 := n1.StartDaemon() + defer d1.MustCleanShutdown(t) + + n1.MustUp() + n1.AwaitRunning() + + revokeCmd := n1.Tailscale("lock", verb) + out, err := revokeCmd.CombinedOutput() + if err == nil { + t.Fatal("expected command to fail, but succeeded") + } + want := "missing argument" + got := string(out) + if !strings.Contains(string(out), want) { + t.Fatalf("expected output to contain %q, got %q", want, got) + } + }) + } + }) } func TestNodeWithBadStateFile(t *testing.T) { diff --git a/tstest/integration/nat/nat_test.go b/tstest/integration/nat/nat_test.go index 2322e243a8ee9..8eca5742f28d8 100644 --- a/tstest/integration/nat/nat_test.go +++ b/tstest/integration/nat/nat_test.go @@ -7,6 +7,7 @@ import ( "bytes" "cmp" "context" + "encoding/json" "errors" "flag" "fmt" @@ -17,6 +18,8 @@ import ( "os" "os/exec" "path/filepath" + "runtime" + "strconv" "strings" "sync" "testing" @@ -74,7 +77,7 @@ func newNatTest(tb testing.TB) *natTest { cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout if err := cmd.Run(); err != nil { - tb.Fatalf("Error running 'make natlab' in gokrazy directory") + tb.Fatalf("Error running 'make natlab' in gokrazy directory: %v", err) } if _, err := os.Stat(nt.base); err != nil { tb.Skipf("still can't find VM image: %v", err) @@ -133,6 +136,24 @@ func easyAnd6(c *vnet.Config) *vnet.Node { vnet.EasyNAT)) } +// easyNoControlDiscoRotate sets up a node with easy NAT, cuts traffic to +// control after connecting, and then rotates the disco key to simulate a newly +// started node (from a disco perspective). +func easyNoControlDiscoRotate(c *vnet.Config) *vnet.Node { + n := c.NumNodes() + 1 + nw := c.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("192.168.%d.1/24", n), + vnet.EasyNAT) + nw.SetPostConnectControlBlackhole(true) + return c.AddNode( + vnet.TailscaledEnv{ + Key: "TS_USE_CACHED_NETMAP", + Value: "true", + }, + vnet.RotateDisco, vnet.PreICMPPing, nw) +} + func v6AndBlackholedIPv4(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 nw := c.AddNetwork( @@ -175,6 +196,22 @@ func sameLAN(c *vnet.Config) *vnet.Node { return c.AddNode(nw) } +func sameLANNoDropCGNAT(c *vnet.Config) *vnet.Node { + nw := c.FirstNetwork() + if nw == nil { + return nil + } + if !nw.CanTakeMoreNodes() { + return nil + } + return c.AddNode( + nw, + tailcfg.NodeCapMap{ + tailcfg.NodeAttrDisableLinuxCGNATDropRule: nil, + }, + ) +} + func one2one(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( @@ -259,7 +296,7 @@ func hardPMP(c *vnet.Config) *vnet.Node { fmt.Sprintf("10.7.%d.1/24", n), vnet.HardNAT, vnet.NATPMP)) } -func (nt *natTest) runTest(addNode ...addNodeFunc) pingRoute { +func (nt *natTest) setupTest(ctx context.Context, addNode ...addNodeFunc) (nodes []*vnet.Node, clients []*vnet.NodeAgentClient, cleanup func()) { if len(addNode) < 1 || len(addNode) > 2 { nt.tb.Fatalf("runTest: invalid number of nodes %v; want 1 or 2", len(addNode)) } @@ -267,7 +304,6 @@ func (nt *natTest) runTest(addNode ...addNodeFunc) pingRoute { var c vnet.Config c.SetPCAPFile(*pcapFile) - nodes := []*vnet.Node{} for _, fn := range addNode { node := fn(&c) if node == nil { @@ -298,9 +334,7 @@ func (nt *natTest) runTest(addNode ...addNodeFunc) pingRoute { } defer srv.Close() - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { for { c, err := srv.Accept() if err != nil { @@ -308,8 +342,17 @@ func (nt *natTest) runTest(addNode ...addNodeFunc) pingRoute { } go nt.vnet.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU) } - }() + }) + + haveKVM := false + if runtime.GOOS == "linux" { + if f, err := os.OpenFile("/dev/kvm", os.O_RDWR, 0); err == nil { + f.Close() + haveKVM = true + } + } + qmpSocks := make([]string, len(nodes)) for i, node := range nodes { disk := fmt.Sprintf("%s/node-%d.qcow2", nt.tempDir, i) out, err := exec.Command("qemu-img", "create", @@ -332,22 +375,28 @@ func (nt *natTest) runTest(addNode ...addNodeFunc) pingRoute { } envStr := envBuf.String() - cmd := exec.Command("qemu-system-x86_64", + qmpSocks[i] = fmt.Sprintf("%s/qmp-node-%d.sock", nt.tempDir, i) + qemuArgs := []string{ "-M", "microvm,isa-serial=off", "-m", "384M", "-nodefaults", "-no-user-config", "-nographic", "-kernel", nt.kernel, - "-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-76baa2d60001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet gokrazy.remote_syslog.target="+sysLogAddr+" tailscale-tta=1"+envStr, - "-drive", "id=blk0,file="+disk+",format=qcow2", + "-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-76baa2d60001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb gokrazy.remote_syslog.target=" + sysLogAddr + " tailscale-tta=1" + envStr, + "-drive", "id=blk0,file=" + disk + ",format=qcow2", "-device", "virtio-blk-device,drive=blk0", - "-netdev", "stream,id=net0,addr.type=unix,addr.path="+sockAddr, + "-netdev", "stream,id=net0,addr.type=unix,addr.path=" + sockAddr, "-device", "virtio-serial-device", "-device", "virtio-rng-device", - "-device", "virtio-net-device,netdev=net0,mac="+node.MAC().String(), + "-device", "virtio-net-device,netdev=net0,mac=" + node.MAC().String(), "-chardev", "stdio,id=virtiocon0,mux=on", "-device", "virtconsole,chardev=virtiocon0", "-mon", "chardev=virtiocon0,mode=readline", - ) + "-qmp", "unix:" + qmpSocks[i] + ",server=on,wait=off", + } + if haveKVM { + qemuArgs = append(qemuArgs, "-enable-kvm", "-cpu", "host") + } + cmd := exec.Command("qemu-system-x86_64", qemuArgs...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { @@ -359,18 +408,23 @@ func (nt *natTest) runTest(addNode ...addNodeFunc) pingRoute { }) } - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() + for i, node := range nodes { + if err := nt.vnet.AwaitFirstPacket(ctx, node.MAC()); err != nil { + t.Logf("node %v: no boot progress (no packets received): %v", node, err) + t.Logf("node %v: QMP status: %s", node, qmpQueryStatus(qmpSocks[i])) + t.FailNow() + } + t.Logf("node %v: boot detected (first packet received)", node) + } - var clients []*vnet.NodeAgentClient for _, n := range nodes { - clients = append(clients, nt.vnet.NodeAgentClient(n)) + client := nt.vnet.NodeAgentClient(n) + n.SetClient(client) + clients = append(clients, client) } - sts := make([]*ipnstate.Status, len(nodes)) var eg errgroup.Group for i, c := range clients { - i, c := i, c eg.Go(func() error { node := nodes[i] t.Logf("%v calling Status...", node) @@ -387,21 +441,31 @@ func (nt *natTest) runTest(addNode ...addNodeFunc) pingRoute { t.Logf("%v firewalled", node) } - if err := up(ctx, c); err != nil { - return fmt.Errorf("%v up: %w", node, err) - } - t.Logf("%v up!", node) + if node.ShouldJoinTailnet() { + if err := up(ctx, c); err != nil { + return fmt.Errorf("%v up: %w", node, err) + } + t.Logf("%v up!", node) - st, err = c.Status(ctx) - if err != nil { - return fmt.Errorf("%v status: %w", node, err) - } - sts[i] = st + st, err = c.Status(ctx) + if err != nil { + return fmt.Errorf("%v status: %w", node, err) + } + + if capMap := node.WantCapMap(); capMap != nil { + nt.tb.Logf("using capmap for %s: %+v", node.String(), capMap) + nt.vnet.ControlServer().SetNodeCapMap(st.Self.PublicKey, capMap) + } + + if st.BackendState != "Running" { + return fmt.Errorf("%v state = %q", node, st.BackendState) + } - if st.BackendState != "Running" { - return fmt.Errorf("%v state = %q", node, st.BackendState) + t.Logf("%v AllowedIPs: %v", node, st.Self.Addrs) + t.Logf("%v up with %v", node, st.Self.TailscaleIPs) + } else { + t.Logf("%v skipping joining tailnet", node) } - t.Logf("%v up with %v", node, sts[i].Self.TailscaleIPs) return nil }) } @@ -409,15 +473,109 @@ func (nt *natTest) runTest(addNode ...addNodeFunc) pingRoute { t.Fatalf("initial setup: %v", err) } - defer nt.vnet.Close() + return nodes, clients, nt.vnet.Close +} + +type hasDeadline interface { + Deadline() (deadline time.Time, ok bool) +} + +// testContext returns a context derived from the test's deadline (from -timeout), +// leaving a small margin for cleanup. Falls back to 60s if no deadline is set. +func testContext(tb testing.TB) (context.Context, context.CancelFunc) { + if t, ok := tb.(hasDeadline); ok { + if dl, ok := t.Deadline(); ok { + const margin = 5 * time.Second + return context.WithDeadline(context.Background(), dl.Add(-margin)) + } + } + return context.WithTimeout(context.Background(), 60*time.Second) +} + +func (nt *natTest) runHostConnectivityTest(addNode ...addNodeFunc) bool { + ctx, cancel := testContext(nt.tb) + defer cancel() + nodes, clients, cleanup := nt.setupTest(ctx, addNode...) + defer cleanup() + + if len(nodes) != 2 { + nt.tb.Logf("ping can only be done among exactly two nodes") + return false + } + var fromClient, toClient *vnet.NodeAgentClient + for i, n := range nodes { + if n.ShouldJoinTailnet() && fromClient == nil { + fromClient = clients[i] + } else { + toClient = clients[i] + } + } + got, err := sendHostNetworkPing(ctx, nt.tb, fromClient, toClient) + if err != nil { + nt.tb.Fatalf("ping host: %v", err) + } + nt.tb.Logf("ping success: %v", got) + return got +} + +func (nt *natTest) runTailscaleConnectivityTest(addNode ...addNodeFunc) pingRoute { + ctx, cancel := testContext(nt.tb) + defer cancel() + + nodes, clients, cleanup := nt.setupTest(ctx, addNode...) + defer cleanup() + t := nt.tb if len(nodes) < 2 { return "" } + for _, n := range nodes { + if !n.ShouldJoinTailnet() { + t.Logf("%v did not join tailnet", n) + return "" + } + } - pingRes, err := ping(ctx, t, clients[0], sts[1].Self.TailscaleIPs[0]) + sts := make([]*ipnstate.Status, len(nodes)) + var eg errgroup.Group + for i, c := range clients { + eg.Go(func() error { + node := nodes[i] + st, err := c.Status(ctx) + if err != nil { + return fmt.Errorf("%v: %w", node, err) + } + sts[i] = st + return nil + }) + } + if err := eg.Wait(); err != nil { + t.Fatalf("get node statuses: %v", err) + } + + preICMPPing := false + for _, node := range nodes { + node.Network().PostConnectedToControl() + if err := node.PostConnectedToControl(ctx); err != nil { + t.Fatalf("post control error: %s", err) + } + if node.PreICMPPing() { + preICMPPing = true + } + } + + // Should we send traffic across the nodes before starting disco? + // For nodes that rotated disco keys after control going away. + if preICMPPing { + _, err := ping(ctx, t, clients[0], sts[1].Self.TailscaleIPs[0], tailcfg.PingICMP) + if err != nil { + t.Fatalf("ICMP ping failure: %v", err) + } + } + + pingRes, err := ping(ctx, t, clients[0], sts[1].Self.TailscaleIPs[0], tailcfg.PingDisco) if err != nil { - t.Fatalf("ping failure: %v", err) + t.Logf("ping failure: %v", err) } nt.gotRoute = classifyPing(pingRes) t.Logf("ping route: %v", nt.gotRoute) @@ -450,12 +608,12 @@ const ( routeNil pingRoute = "nil" // *ipnstate.PingResult is nil ) -func ping(ctx context.Context, t testing.TB, c *vnet.NodeAgentClient, target netip.Addr) (*ipnstate.PingResult, error) { +func ping(ctx context.Context, t testing.TB, c *vnet.NodeAgentClient, target netip.Addr, pType tailcfg.PingType) (*ipnstate.PingResult, error) { var lastRes *ipnstate.PingResult for n := range 10 { t.Logf("ping attempt %d to %v ...", n+1, target) pingCtx, cancel := context.WithTimeout(ctx, 2*time.Second) - pr, err := c.PingWithOpts(pingCtx, target, tailcfg.PingDisco, tailscale.PingOpts{}) + pr, err := c.PingWithOpts(pingCtx, target, pType, tailscale.PingOpts{}) cancel() if err != nil { t.Logf("ping attempt %d error: %v", n+1, err) @@ -484,6 +642,55 @@ func ping(ctx context.Context, t testing.TB, c *vnet.NodeAgentClient, target net return nil, fmt.Errorf("no ping response (ctx: %v)", ctx.Err()) } +// qmpQueryStatus connects to a QEMU QMP socket and returns the VM status +// (e.g. "running", "paused", "prelaunch") or an error string. +func qmpQueryStatus(sockPath string) string { + conn, err := net.DialTimeout("unix", sockPath, 2*time.Second) + if err != nil { + return fmt.Sprintf("dial error: %v", err) + } + defer conn.Close() + conn.SetDeadline(time.Now().Add(5 * time.Second)) + dec := json.NewDecoder(conn) + + // Read QMP greeting. + var greeting json.RawMessage + if err := dec.Decode(&greeting); err != nil { + return fmt.Sprintf("greeting error: %v", err) + } + + // Enter command mode. + if _, err := conn.Write([]byte(`{"execute":"qmp_capabilities"}` + "\n")); err != nil { + return fmt.Sprintf("write caps: %v", err) + } + var capsResp json.RawMessage + if err := dec.Decode(&capsResp); err != nil { + return fmt.Sprintf("caps response: %v", err) + } + + // Query status. + if _, err := conn.Write([]byte(`{"execute":"query-status"}` + "\n")); err != nil { + return fmt.Sprintf("write query-status: %v", err) + } + var statusResp struct { + Return struct { + Running bool `json:"running"` + Status string `json:"status"` + } `json:"return"` + Error *struct { + Class string `json:"class"` + Desc string `json:"desc"` + } `json:"error"` + } + if err := dec.Decode(&statusResp); err != nil { + return fmt.Sprintf("status response: %v", err) + } + if statusResp.Error != nil { + return fmt.Sprintf("qmp error: %s: %s", statusResp.Error.Class, statusResp.Error.Desc) + } + return fmt.Sprintf("status=%s running=%v", statusResp.Return.Status, statusResp.Return.Running) +} + func up(ctx context.Context, c *vnet.NodeAgentClient) error { req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/up", nil) if err != nil { @@ -501,6 +708,60 @@ func up(ctx context.Context, c *vnet.NodeAgentClient) error { return nil } +func getClientIP(ctx context.Context, c *vnet.NodeAgentClient) (netip.Addr, error) { + getIPReq, err := http.NewRequestWithContext(ctx, "GET", "http://unused/ip", nil) + if err != nil { + return netip.Addr{}, err + } + res, err := c.HTTPClient.Do(getIPReq) + if err != nil { + return netip.Addr{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return netip.Addr{}, fmt.Errorf("client returned http status %q", res.Status) + } + ipBytes, err := io.ReadAll(res.Body) + if err != nil { + return netip.Addr{}, err + } + addrPort, err := netip.ParseAddrPort(string(ipBytes)) + if err != nil { + return netip.Addr{}, err + } + return addrPort.Addr(), nil +} + +// sendHostNetworkPing pings toClient from fromClient, and returns whether +// toClient responded to the ping. +func sendHostNetworkPing(ctx context.Context, tb testing.TB, fromClient, toClient *vnet.NodeAgentClient) (bool, error) { + toIP, err := getClientIP(ctx, toClient) + if err != nil { + return false, fmt.Errorf("get ip: %w", err) + } + req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://unused/ping?host=%s", toIP.String()), nil) + if err != nil { + return false, err + } + res, err := fromClient.HTTPClient.Do(req) + if err != nil { + return false, err + } + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + if err != nil { + tb.Logf("error while reading http body: %v", err) + } else { + tb.Logf("got response from ping: %q", got) + } + ec, err := strconv.Atoi(res.Header.Get("Exec-Exit-Code")) + if err != nil { + return false, fmt.Errorf("parse exit code: %w", err) + } + tb.Logf("got ec: %v", ec) + return ec == 0, nil +} + type nodeType struct { name string fn addNodeFunc @@ -514,6 +775,7 @@ var types = []nodeType{ {"hardPMP", hardPMP}, {"one2one", one2one}, {"sameLAN", sameLAN}, + {"cgnat", cgnatNoTailnet}, } // want sets the expected ping route for the test. @@ -525,10 +787,37 @@ func (nt *natTest) want(r pingRoute) { func TestEasyEasy(t *testing.T) { nt := newNatTest(t) - nt.runTest(easy, easy) + nt.runTailscaleConnectivityTest(easy, easy) nt.want(routeDirect) } +// TestTwoEasyNoControlDiscoRotate tests a situation where two nodes have been +// online and connected through control, but then loose control access and also +// rotate keys. It is not a perfect proxy for a cached node, as the node will +// still have a mapState and not use the backup method of inserting keys into +// the engine directly. +func TestTwoEasyNoControlDiscoRotate(t *testing.T) { + nt := newNatTest(t) + nt.runTailscaleConnectivityTest(easyNoControlDiscoRotate, easyNoControlDiscoRotate) + nt.want(routeDirect) +} + +func cgnatNoTailnet(c *vnet.Config) *vnet.Node { + n := c.NumNodes() + 1 + return c.AddNode(c.AddNetwork( + fmt.Sprintf("100.65.%d.1/16", n), + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + vnet.EasyNAT), + vnet.DontJoinTailnet) +} + +func TestNonTailscaleCGNATEndpoint(t *testing.T) { + nt := newNatTest(t) + if !nt.runHostConnectivityTest(cgnatNoTailnet, sameLANNoDropCGNAT) { + t.Fatalf("could not ping") + } +} + // Issue tailscale/corp#26438: use learned DERP route as send path of last // resort // @@ -545,13 +834,13 @@ func TestEasyEasy(t *testing.T) { // packet over a particular DERP from that peer. func TestFallbackDERPRegionForPeer(t *testing.T) { nt := newNatTest(t) - nt.runTest(hard, hardNoDERPOrEndoints) + nt.runTailscaleConnectivityTest(hard, hardNoDERPOrEndoints) nt.want(routeDERP) } func TestSingleJustIPv6(t *testing.T) { nt := newNatTest(t) - nt.runTest(just6) + nt.runTailscaleConnectivityTest(just6) } var knownBroken = flag.Bool("known-broken", false, "run known-broken tests") @@ -565,24 +854,24 @@ func TestSingleDualBrokenIPv4(t *testing.T) { t.Skip("skipping known-broken test; set --known-broken to run; see https://github.com/tailscale/tailscale/issues/13346") } nt := newNatTest(t) - nt.runTest(v6AndBlackholedIPv4) + nt.runTailscaleConnectivityTest(v6AndBlackholedIPv4) } func TestJustIPv6(t *testing.T) { nt := newNatTest(t) - nt.runTest(just6, just6) + nt.runTailscaleConnectivityTest(just6, just6) nt.want(routeDirect) } func TestEasy4AndJust6(t *testing.T) { nt := newNatTest(t) - nt.runTest(easyAnd6, just6) + nt.runTailscaleConnectivityTest(easyAnd6, just6) nt.want(routeDirect) } func TestSameLAN(t *testing.T) { nt := newNatTest(t) - nt.runTest(easy, sameLAN) + nt.runTailscaleConnectivityTest(easy, sameLAN) nt.want(routeLocal) } @@ -592,25 +881,25 @@ func TestSameLAN(t *testing.T) { // * client machine has a stateful host firewall (e.g. ufw) func TestBPFDisco(t *testing.T) { nt := newNatTest(t) - nt.runTest(easyPMPFWPlusBPF, hard) + nt.runTailscaleConnectivityTest(easyPMPFWPlusBPF, hard) nt.want(routeDirect) } func TestHostFWNoBPF(t *testing.T) { nt := newNatTest(t) - nt.runTest(easyPMPFWNoBPF, hard) + nt.runTailscaleConnectivityTest(easyPMPFWNoBPF, hard) nt.want(routeDERP) } func TestHostFWPair(t *testing.T) { nt := newNatTest(t) - nt.runTest(easyFW, easyFW) + nt.runTailscaleConnectivityTest(easyFW, easyFW) nt.want(routeDirect) } func TestOneHostFW(t *testing.T) { nt := newNatTest(t) - nt.runTest(easy, easyFW) + nt.runTailscaleConnectivityTest(easy, easyFW) nt.want(routeDirect) } @@ -632,7 +921,7 @@ func TestPair(t *testing.T) { } nt := newNatTest(t) - nt.runTest(find(t1), find(t2)) + nt.runTailscaleConnectivityTest(find(t1), find(t2)) } var runGrid = flag.Bool("run-grid", false, "run grid test") @@ -668,7 +957,7 @@ func TestGrid(t *testing.T) { if route == "" { nt := newNatTest(t) - route = nt.runTest(a.fn, b.fn) + route = nt.runTailscaleConnectivityTest(a.fn, b.fn) if err := os.WriteFile(filename, []byte(string(route)), 0666); err != nil { t.Fatalf("writeFile: %v", err) } diff --git a/tstest/integration/tailscaled_deps_test_darwin.go b/tstest/integration/tailscaled_deps_test_darwin.go index 112f04767c89d..70e0d75faf3eb 100644 --- a/tstest/integration/tailscaled_deps_test_darwin.go +++ b/tstest/integration/tailscaled_deps_test_darwin.go @@ -20,6 +20,7 @@ import ( _ "tailscale.com/feature" _ "tailscale.com/feature/buildfeatures" _ "tailscale.com/feature/condregister" + _ "tailscale.com/feature/ssh" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" @@ -40,7 +41,6 @@ import ( _ "tailscale.com/net/tstun" _ "tailscale.com/paths" _ "tailscale.com/safesocket" - _ "tailscale.com/ssh/tailssh" _ "tailscale.com/syncs" _ "tailscale.com/tailcfg" _ "tailscale.com/tsd" diff --git a/tstest/integration/tailscaled_deps_test_freebsd.go b/tstest/integration/tailscaled_deps_test_freebsd.go index 112f04767c89d..70e0d75faf3eb 100644 --- a/tstest/integration/tailscaled_deps_test_freebsd.go +++ b/tstest/integration/tailscaled_deps_test_freebsd.go @@ -20,6 +20,7 @@ import ( _ "tailscale.com/feature" _ "tailscale.com/feature/buildfeatures" _ "tailscale.com/feature/condregister" + _ "tailscale.com/feature/ssh" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" @@ -40,7 +41,6 @@ import ( _ "tailscale.com/net/tstun" _ "tailscale.com/paths" _ "tailscale.com/safesocket" - _ "tailscale.com/ssh/tailssh" _ "tailscale.com/syncs" _ "tailscale.com/tailcfg" _ "tailscale.com/tsd" diff --git a/tstest/integration/tailscaled_deps_test_linux.go b/tstest/integration/tailscaled_deps_test_linux.go index 112f04767c89d..70e0d75faf3eb 100644 --- a/tstest/integration/tailscaled_deps_test_linux.go +++ b/tstest/integration/tailscaled_deps_test_linux.go @@ -20,6 +20,7 @@ import ( _ "tailscale.com/feature" _ "tailscale.com/feature/buildfeatures" _ "tailscale.com/feature/condregister" + _ "tailscale.com/feature/ssh" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" @@ -40,7 +41,6 @@ import ( _ "tailscale.com/net/tstun" _ "tailscale.com/paths" _ "tailscale.com/safesocket" - _ "tailscale.com/ssh/tailssh" _ "tailscale.com/syncs" _ "tailscale.com/tailcfg" _ "tailscale.com/tsd" diff --git a/tstest/integration/tailscaled_deps_test_openbsd.go b/tstest/integration/tailscaled_deps_test_openbsd.go index 112f04767c89d..70e0d75faf3eb 100644 --- a/tstest/integration/tailscaled_deps_test_openbsd.go +++ b/tstest/integration/tailscaled_deps_test_openbsd.go @@ -20,6 +20,7 @@ import ( _ "tailscale.com/feature" _ "tailscale.com/feature/buildfeatures" _ "tailscale.com/feature/condregister" + _ "tailscale.com/feature/ssh" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" @@ -40,7 +41,6 @@ import ( _ "tailscale.com/net/tstun" _ "tailscale.com/paths" _ "tailscale.com/safesocket" - _ "tailscale.com/ssh/tailssh" _ "tailscale.com/syncs" _ "tailscale.com/tailcfg" _ "tailscale.com/tsd" diff --git a/tstest/integration/tailscaled_deps_test_windows.go b/tstest/integration/tailscaled_deps_test_windows.go index cabac744a5c6c..00768c99e79c1 100644 --- a/tstest/integration/tailscaled_deps_test_windows.go +++ b/tstest/integration/tailscaled_deps_test_windows.go @@ -33,7 +33,6 @@ import ( _ "tailscale.com/ipn" _ "tailscale.com/ipn/auditlog" _ "tailscale.com/ipn/conffile" - _ "tailscale.com/ipn/desktop" _ "tailscale.com/ipn/ipnlocal" _ "tailscale.com/ipn/ipnserver" _ "tailscale.com/ipn/store" diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 1e24414903ae9..c96b1ed33a126 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -38,7 +38,6 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/opt" - "tailscale.com/types/ptr" "tailscale.com/util/httpm" "tailscale.com/util/mak" "tailscale.com/util/must" @@ -70,6 +69,22 @@ type Server struct { // belong to the same user. AllNodesSameUser bool + // AllOnline, if true, marks every peer entry in MapResponses as + // Online=true. This is a coarse stand-in for the per-node + // online/offline tracking that production control servers do based + // on streaming map sessions: certain disco-key handling fast paths + // in [tailscale.com/control/controlclient] and + // [tailscale.com/wgengine/userspace] only fire when the peer is + // reported online, so without this flag they are silently skipped + // in tests, which can mask bugs and slow down recovery from disco + // rotations. See [tailscale.com/control/controlclient/map.go] + // removeUnwantedDiscoUpdates and + // removeUnwantedDiscoUpdatesFromFullNetmapUpdate for callers that + // branch on Online. + // + // Finer-grained per-node online tracking can be added later. + AllOnline bool + // DefaultNodeCapabilities overrides the capability map sent to each client. DefaultNodeCapabilities *tailcfg.NodeCapMap @@ -81,10 +96,18 @@ type Server struct { ExplicitBaseURL string // e.g. "http://127.0.0.1:1234" with no trailing URL HTTPTestServer *httptest.Server // if non-nil, used to get BaseURL + // MaybeRateLimitRegister, if non-nil, is called before processing + // register requests. If it returns true, a 429 response is sent + // with the given Retry-After header value and body string. + MaybeRateLimitRegister func() (reject bool, retryAfter string, msg string) + // ModifyFirstMapResponse, if non-nil, is called exactly once per // MapResponse stream to modify the first MapResponse sent in response to it. ModifyFirstMapResponse func(*tailcfg.MapResponse, *tailcfg.MapRequest) + // AltMapStream, if non-nil, takes over serveMap. See [AltMapStreamFunc]. + AltMapStream AltMapStreamFunc + initMuxOnce sync.Once mux *http.ServeMux @@ -133,12 +156,16 @@ type Server struct { updates map[tailcfg.NodeID]chan updateType authPath map[string]*AuthPath nodeKeyAuthed set.Set[key.NodePublic] - msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse - allExpired bool // All nodes will be told their node key is expired. + msgToSend map[key.NodePublic][]any // FIFO queue per node; values are *tailcfg.PingRequest or *tailcfg.MapResponse + allExpired bool // All nodes will be told their node key is expired. // tkaStorage records the Tailnet Lock state, if any. // If nil, Tailnet Lock is not enabled in the Tailnet. tkaStorage tka.CompactableChonk + + // onMapRequest, if non-nil, is called at the start of each map poll request. + // It can be used in tests to panic or fail if a node contacts control unexpectedly. + onMapRequest func(nodeKey key.NodePublic) } // BaseURL returns the server's base URL, without trailing slash. @@ -277,14 +304,16 @@ func (s *Server) AddRawMapResponse(nodeKeyDst key.NodePublic, mr *tailcfg.MapRes func (s *Server) addDebugMessage(nodeKeyDst key.NodePublic, msg any) bool { s.mu.Lock() defer s.mu.Unlock() - if s.msgToSend == nil { - s.msgToSend = map[key.NodePublic]any{} - } - // Now send the update to the channel node := s.nodeLocked(nodeKeyDst) if node == nil { return false } + updatesCh := s.updates[node.ID] + if updatesCh == nil { + // No streaming poll is registered, so there's nobody to deliver + // the message to. + return false + } if _, ok := msg.(*tailcfg.MapResponse); ok { if s.suppressAutoMapResponses == nil { @@ -293,10 +322,14 @@ func (s *Server) addDebugMessage(nodeKeyDst key.NodePublic, msg any) bool { s.suppressAutoMapResponses.Add(nodeKeyDst) } - s.msgToSend[nodeKeyDst] = msg - nodeID := node.ID - oldUpdatesCh := s.updates[nodeID] - return sendUpdate(oldUpdatesCh, updateDebugInjection) + mak.Set(&s.msgToSend, nodeKeyDst, append(s.msgToSend[nodeKeyDst], msg)) + // sendUpdate returning false here is fine: the channel is a lossy + // wake-up signal whose buffer is single-slot. A full buffer means a + // prior wake-up is still pending, and the streaming poll will check + // msgToSend when it processes that wake-up. The queue in msgToSend + // is the source of truth. + sendUpdate(updatesCh, updateDebugInjection) + return true } // Mark the Node key of every node as expired @@ -487,6 +520,13 @@ func (s *Server) SetSubnetRoutes(nodeKey key.NodePublic, routes []netip.Prefix) mak.Set(&s.nodeSubnetRoutes, nodeKey, routes) if node, ok := s.nodes[nodeKey]; ok { sendUpdate(s.updates[node.ID], updateSelfChanged) + // Also notify all other peers so they get the updated AllowedIPs + // in their next MapResponse. + for _, n := range s.nodes { + if n.ID != node.ID { + sendUpdate(s.updates[n.ID], updatePeerChanged) + } + } } } @@ -762,6 +802,16 @@ func (s *Server) CompleteDeviceApproval(controlUrl string, urlStr string, nodeKe } func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key.MachinePublic) { + if fn := s.MaybeRateLimitRegister; fn != nil { + if reject, retryAfter, msg := fn(); reject { + if retryAfter != "" { + w.Header().Set("Retry-After", retryAfter) + } + http.Error(w, msg, http.StatusTooManyRequests) + return + } + } + msg, err := io.ReadAll(io.LimitReader(r.Body, msgLimit)) r.Body.Close() if err != nil { @@ -1123,6 +1173,21 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi go panic(fmt.Sprintf("bad map request: %v", err)) } + s.mu.Lock() + if s.onMapRequest != nil { + s.onMapRequest(req.NodeKey) + } + s.mu.Unlock() + + if s.AltMapStream != nil { + // The caller takes over the stream entirely; it must handle + // keeping the HTTP response alive until ctx is done. + compress := req.Compress != "" + w.WriteHeader(200) + s.AltMapStream(ctx, &mapStreamSender{s: s, w: w, compress: compress}, req) + return + } + jitter := rand.N(8 * time.Second) keepAlive := 50*time.Second + jitter @@ -1136,8 +1201,15 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi return } + // Per tailcfg.MapRequest.Stream docs: if Stream is true and Version >= 68, + // the server must treat this as read-only and ignore Hostinfo, Endpoints, + // DiscoKey, etc. — modern clients send those via a separate non-streaming + // POST /machine/map from a dedicated updateRoutine, not piggybacked on the + // streaming poll. Without this, the streaming MapRequest's zero-valued + // DiscoKey/Endpoints clobber whatever was just pushed out-of-band. + streamingNonUpdate := req.Stream && req.Version >= 68 var peersToUpdate []tailcfg.NodeID - if !req.ReadOnly { + if !req.ReadOnly && !streamingNonUpdate { endpoints := filterInvalidIPv6Endpoints(req.Endpoints) node.Endpoints = endpoints node.DiscoKey = req.DiscoKey @@ -1337,9 +1409,9 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, } if masqIP := nodeMasqs[p.Key]; masqIP.IsValid() { if masqIP.Is6() { - p.SelfNodeV6MasqAddrForThisPeer = ptr.To(masqIP) + p.SelfNodeV6MasqAddrForThisPeer = new(masqIP) } else { - p.SelfNodeV4MasqAddrForThisPeer = ptr.To(masqIP) + p.SelfNodeV4MasqAddrForThisPeer = new(masqIP) } } p.IsJailed = jailed[p.Key] @@ -1365,6 +1437,9 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, p.PrimaryRoutes = routes p.AllowedIPs = append(p.AllowedIPs, routes...) } + if s.AllOnline { + p.Online = new(true) + } res.Peers = append(res.Peers, p) } @@ -1413,15 +1488,29 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, res.Node.PrimaryRoutes = s.nodeSubnetRoutes[nk] res.Node.AllowedIPs = append(res.Node.Addresses, s.nodeSubnetRoutes[nk]...) - // Consume a PingRequest while protected by mutex if it exists - switch m := s.msgToSend[nk].(type) { - case *tailcfg.PingRequest: - res.PingRequest = m - delete(s.msgToSend, nk) + // Consume a PingRequest at the head of the queue, if any. + if q := s.msgToSend[nk]; len(q) > 0 { + if pr, ok := q[0].(*tailcfg.PingRequest); ok { + res.PingRequest = pr + s.popMsgToSendLocked(nk) + } } return res, nil } +// popMsgToSendLocked pops the head of the per-node message queue. +// s.mu must be held. +func (s *Server) popMsgToSendLocked(nk key.NodePublic) { + q := s.msgToSend[nk] + if len(q) <= 1 { + delete(s.msgToSend, nk) + return + } + // Zero the head to allow GC of any large referenced response. + q[0] = nil + s.msgToSend[nk] = q[1:] +} + func (s *Server) canGenerateAutomaticMapResponseFor(nk key.NodePublic) bool { s.mu.Lock() defer s.mu.Unlock() @@ -1431,22 +1520,21 @@ func (s *Server) canGenerateAutomaticMapResponseFor(nk key.NodePublic) bool { func (s *Server) hasPendingRawMapMessage(nk key.NodePublic) bool { s.mu.Lock() defer s.mu.Unlock() - _, ok := s.msgToSend[nk] - return ok + return len(s.msgToSend[nk]) > 0 } func (s *Server) takeRawMapMessage(nk key.NodePublic) (mapResJSON []byte, ok bool) { s.mu.Lock() defer s.mu.Unlock() - mr, ok := s.msgToSend[nk] - if !ok { + q := s.msgToSend[nk] + if len(q) == 0 { return nil, false } - delete(s.msgToSend, nk) + mr := q[0] + s.popMsgToSendLocked(nk) // If it's a bare PingRequest, wrap it in a MapResponse. - switch pr := mr.(type) { - case *tailcfg.PingRequest: + if pr, ok := mr.(*tailcfg.PingRequest); ok { mr = &tailcfg.MapResponse{PingRequest: pr} } @@ -1458,12 +1546,51 @@ func (s *Server) takeRawMapMessage(nk key.NodePublic) (mapResJSON []byte, ok boo return mapResJSON, true } +// AltMapStreamFunc is the type of [Server.AltMapStream]: a callback that +// takes over the serveMap handler entirely. The callback hand-builds and +// sends MapResponses via the provided [MapStreamWriter] and is responsible +// for keeping the stream alive until ctx is done. When set, the normal +// per-node map-stream state machine in serveMap is bypassed. +// +// The callback is invoked for every map long-poll, including the +// non-streaming "lite" polls controlclient issues to push HostInfo updates +// (req.Stream == false). Implementations that only care about the streaming +// long-poll typically respond to non-streaming polls with an empty +// MapResponse and return immediately. +// +// This hook is for benchmarks and stress tests that need to drive clients +// with a controlled sequence of responses. +type AltMapStreamFunc func(ctx context.Context, w MapStreamWriter, req *tailcfg.MapRequest) + +// MapStreamWriter is the interface passed to an [AltMapStreamFunc], +// letting the callback write framed MapResponse messages directly onto the +// long-poll HTTP response. +type MapStreamWriter interface { + // SendMapMessage encodes and writes msg as a single framed + // MapResponse on the stream. It respects the client's Compress flag + // (captured when the stream started). + SendMapMessage(msg *tailcfg.MapResponse) error +} + +// mapStreamSender implements [MapStreamWriter] for [Server.AltMapStream] +// callbacks. +type mapStreamSender struct { + s *Server + w http.ResponseWriter + compress bool +} + +func (m *mapStreamSender) SendMapMessage(msg *tailcfg.MapResponse) error { + return m.s.sendMapMsg(m.w, m.compress, msg) +} + func (s *Server) sendMapMsg(w http.ResponseWriter, compress bool, msg any) error { resBytes, err := s.encode(compress, msg) if err != nil { return err } - if len(resBytes) > 16<<20 { + const maxMapSize = 256 << 20 // 256MB + if len(resBytes) > maxMapSize { return fmt.Errorf("map message too big: %d", len(resBytes)) } var siz [4]byte @@ -1503,6 +1630,15 @@ func (s *Server) encode(compress bool, v any) (b []byte, err error) { return b, nil } +// SetOnMapRequest sets callback used for testing when a new mapRequest happens. +// Pass nil to remove the callback. +func (s *Server) SetOnMapRequest(f func(key.NodePublic)) { + s.mu.Lock() + defer s.mu.Unlock() + + s.onMapRequest = f +} + // filterInvalidIPv6Endpoints removes invalid IPv6 endpoints from eps, // modify the slice in place, returning the potentially smaller subset (aliasing // the original memory). diff --git a/tstest/integration/testcontrol/testcontrol_test.go b/tstest/integration/testcontrol/testcontrol_test.go new file mode 100644 index 0000000000000..d3008cdb7f43b --- /dev/null +++ b/tstest/integration/testcontrol/testcontrol_test.go @@ -0,0 +1,132 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package testcontrol_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "tailscale.com/control/ts2021" + "tailscale.com/control/tsp" + "tailscale.com/net/tsdial" + "tailscale.com/tailcfg" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/key" + "tailscale.com/util/must" +) + +// TestStreamingMapReqReadOnlyByVersion verifies that testcontrol matches +// production control's streaming-is-read-only semantics for clients at +// capability version >= 68. Per tailcfg.MapRequest.Stream docs, a streaming +// MapRequest from a cap>=68 client must be treated as read-only by the +// server (Endpoints/Hostinfo/DiscoKey are sent separately via a non-streaming +// /machine/map call), so the streaming MapRequest's zero-valued DiscoKey +// must not clobber the node's currently stored DiscoKey. +// +// For older (cap<68) clients, the streaming MapRequest is still a write and +// writes do happen, so DiscoKey=zero in the request does clobber. +func TestStreamingMapReqReadOnlyByVersion(t *testing.T) { + tests := []struct { + version tailcfg.CapabilityVersion + wantClobber bool + }{ + {67, true}, // pre-cap-68: streaming is a write, DiscoKey=zero clobbers. + {68, false}, // cap>=68: streaming is read-only, DiscoKey unchanged. + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("v%d", tt.version), func(t *testing.T) { + ctrl := &testcontrol.Server{} + ctrl.HTTPTestServer = httptest.NewUnstartedServer(ctrl) + ctrl.HTTPTestServer.Start() + t.Cleanup(ctrl.HTTPTestServer.Close) + baseURL := ctrl.HTTPTestServer.URL + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + serverKey := must.Get(tsp.DiscoverServerKey(ctx, baseURL)) + + // Register a node and push a known DiscoKey via SendMapUpdate + // (a non-streaming, unambiguously-a-write request). + nodeKey := key.NewNode() + machineKey := key.NewMachine() + wantDisco := key.NewDisco().Public() + + tc := must.Get(tsp.NewClient(tsp.ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKey, + })) + defer tc.Close() + tc.SetControlPublicKey(serverKey) + must.Get(tc.Register(ctx, tsp.RegisterOpts{ + NodeKey: nodeKey, + Hostinfo: &tailcfg.Hostinfo{Hostname: "target"}, + })) + if err := tc.SendMapUpdate(ctx, tsp.SendMapUpdateOpts{ + NodeKey: nodeKey, + DiscoKey: wantDisco, + Hostinfo: &tailcfg.Hostinfo{Hostname: "target"}, + }); err != nil { + t.Fatalf("SendMapUpdate: %v", err) + } + if n := ctrl.Node(nodeKey.Public()); n == nil || n.DiscoKey != wantDisco { + t.Fatalf("pre: DiscoKey not set; node=%+v", n) + } + + // Fire a streaming MapRequest with the chosen Version and a + // zero DiscoKey. Use ts2021 directly because tsp.Map hardcodes + // Version to tailcfg.CurrentCapabilityVersion. + nc := must.Get(ts2021.NewClient(ts2021.ClientOpts{ + ServerURL: baseURL, + PrivKey: machineKey, + ServerPubKey: serverKey, + Dialer: tsdial.NewFromFuncForDebug(t.Logf, (&net.Dialer{}).DialContext), + })) + defer nc.Close() + + body := must.Get(json.Marshal(&tailcfg.MapRequest{ + Version: tt.version, + NodeKey: nodeKey.Public(), + Stream: true, + // DiscoKey intentionally zero. + })) + reqURL := strings.Replace(baseURL+"/machine/map", "http:", "https:", 1) + reqCtx, reqCancel := context.WithCancel(ctx) + defer reqCancel() + req := must.Get(http.NewRequestWithContext(reqCtx, "POST", reqURL, bytes.NewReader(body))) + ts2021.AddLBHeader(req, nodeKey.Public()) + + // nc.Do returns once response headers arrive, which in + // testcontrol's serveMap is AFTER the write branch has run + // (or been skipped). So by the time this returns, any write + // this request is going to do has already happened. + res, err := nc.Do(req) + if err != nil { + t.Fatalf("nc.Do: %v", err) + } + res.Body.Close() // tears down the streaming session server-side + + got := ctrl.Node(nodeKey.Public()) + if got == nil { + t.Fatal("node disappeared") + } + switch { + case tt.wantClobber && !got.DiscoKey.IsZero(): + t.Errorf("v%d: expected DiscoKey clobbered to zero, got %v", tt.version, got.DiscoKey) + case !tt.wantClobber && got.DiscoKey != wantDisco: + t.Errorf("v%d: DiscoKey changed from %v to %v; should have been left alone", + tt.version, wantDisco, got.DiscoKey) + } + }) + } +} diff --git a/tstest/integration/vms/distros.go b/tstest/integration/vms/distros.go index 94f11c77aac5d..b6312dba45c46 100644 --- a/tstest/integration/vms/distros.go +++ b/tstest/integration/vms/distros.go @@ -35,11 +35,10 @@ func (d *Distro) InstallPre() string { return ` - [ dnf, install, "-y", iptables ]` case "apt": - return ` - [ apt-get, update ] - - [ apt-get, "-y", install, curl, "apt-transport-https", gnupg2 ]` + return ` - [ apt-get, "-y", install, curl, "apt-transport-https", gnupg2 ]` case "apk": - return ` - [ apk, "-U", add, curl, "ca-certificates", iptables, ip6tables ] + return ` - [ apk, add, curl, "ca-certificates", iptables, ip6tables ] - [ modprobe, tun ]` } diff --git a/tstest/integration/vms/vms_test.go b/tstest/integration/vms/vms_test.go index 5ebb12b71032b..ed64acb91f4eb 100644 --- a/tstest/integration/vms/vms_test.go +++ b/tstest/integration/vms/vms_test.go @@ -355,7 +355,7 @@ func (h *Harness) testDistro(t *testing.T, d Distro, ipm ipMapping) { }) }) - t.Run("tailscale status", func(t *testing.T) { + t.Run("tailscale-status", func(t *testing.T) { dur := 100 * time.Millisecond var outp []byte var err error @@ -364,7 +364,7 @@ func (h *Harness) testDistro(t *testing.T, d Distro, ipm ipMapping) { // starts with testcontrol sometimes there can be up to a few seconds where // tailscaled is in an unknown state on these virtual machines. This exponential // delay loop should delay long enough for tailscaled to be ready. - for count := 0; count < 10; count++ { + for range 10 { sess := getSession(t, cli) outp, err = sess.CombinedOutput("tailscale status") @@ -383,7 +383,7 @@ func (h *Harness) testDistro(t *testing.T, d Distro, ipm ipMapping) { t.Fatalf("error: %v", err) }) - t.Run("dump routes", func(t *testing.T) { + t.Run("dump-routes", func(t *testing.T) { sess, err := cli.NewSession() if err != nil { t.Fatal(err) diff --git a/tstest/integration/whois_test.go b/tstest/integration/whois_test.go new file mode 100644 index 0000000000000..b4e99a547028b --- /dev/null +++ b/tstest/integration/whois_test.go @@ -0,0 +1,152 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package integration + +import ( + "context" + "errors" + "fmt" + "net" + "testing" + "time" + + "tailscale.com/tstest" +) + +// TestUserspaceWhoIsProxyMap verifies that WhoIs lookups work via the +// proxymap in userspace-networking mode. It sets up two nodes (n1 and +// n2), starts a TCP listener on localhost, and has n1 connect to n2's +// Tailscale IP on the listener's port via "tailscale nc". Node n2's +// netstack forwards the connection to localhost, and the listener +// calls WhoIs on n2's LocalAPI to identify the remote peer as n1. +func TestUserspaceWhoIsProxyMap(t *testing.T) { + tstest.Shard(t) + tstest.Parallel(t) + env := NewTestEnv(t) + + n1 := NewTestNode(t, env) + d1 := n1.StartDaemon() + + n2 := NewTestNode(t, env) + d2 := n2.StartDaemon() + + n1.AwaitListening() + n2.AwaitListening() + n1.MustUp() + n2.MustUp() + n1.AwaitRunning() + n2.AwaitRunning() + + // Wait for n1 to see n2 as a peer. + if err := tstest.WaitFor(10*time.Second, func() error { + st := n1.MustStatus() + if len(st.Peer) == 0 { + return errors.New("no peers") + } + return nil + }); err != nil { + t.Fatal(err) + } + + // Verify the two nodes have different users. If they were the + // same user, a WhoIs hit could pass trivially. + st1 := n1.MustStatus() + st2 := n2.MustStatus() + if st1.Self.UserID == st2.Self.UserID { + t.Fatalf("n1 and n2 have the same UserID %v; want different users", st1.Self.UserID) + } + t.Logf("n1: UserID=%v", st1.Self.UserID) + t.Logf("n2: UserID=%v", st2.Self.UserID) + + n2IP := n2.AwaitIP4() + t.Logf("n2 IP: %v", n2IP) + + // Start a TCP listener on localhost:0. When n1 connects to n2's + // Tailscale IP on this port, n2's netstack (userspace networking) + // will forward the connection to 127.0.0.1:. The listener + // uses n2's LocalAPI WhoIs to identify the connecting peer. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + port := ln.Addr().(*net.TCPAddr).Port + t.Logf("listener on port %d", port) + + type result struct { + msg string + err error + } + resultCh := make(chan result, 1) + + go func() { + conn, err := ln.Accept() + if err != nil { + resultCh <- result{err: fmt.Errorf("accept: %w", err)} + return + } + defer conn.Close() + + // The RemoteAddr is 127.0.0.1:, the local side of + // n2's netstack dial. WhoIs on n2 should resolve this via the + // proxymap to n1's Tailscale identity. + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + who, err := n2.LocalClient().WhoIs(ctx, conn.RemoteAddr().String()) + if err != nil { + resultCh <- result{err: fmt.Errorf("WhoIs(%q): %w", conn.RemoteAddr(), err)} + return + } + if who.Node == nil { + resultCh <- result{err: errors.New("WhoIs returned nil Node")} + return + } + if who.UserProfile == nil { + resultCh <- result{err: errors.New("WhoIs returned nil UserProfile")} + return + } + + msg := fmt.Sprintf("Hello, %s (%v %v)!", + who.UserProfile.LoginName, who.Node.Name, who.Node.ID) + conn.Write([]byte(msg)) + resultCh <- result{msg: msg} + }() + + // Use "tailscale nc" on n1 to connect to n2's Tailscale IP on + // the listener port. This goes through n1's tailscaled, over + // wireguard to n2's netstack, which dials localhost:. + // + // We need to keep stdin open so nc doesn't exit before reading + // the server's response (nc returns on the first goroutine to + // complete: stdin→conn or conn→stdout). + cmd := n1.TailscaleForOutput("nc", n2IP.String(), fmt.Sprint(port)) + stdin, err := cmd.StdinPipe() + if err != nil { + t.Fatal(err) + } + out, err := cmd.Output() + stdin.Close() + if err != nil { + t.Fatalf("tailscale nc: %v", err) + } + + // Verify the listener goroutine completed without error. + r := <-resultCh + if r.err != nil { + t.Fatal(r.err) + } + + got := string(out) + if got != r.msg { + t.Fatalf("nc output %q doesn't match server-sent message %q", got, r.msg) + } + const wantPrefix = "Hello, user-1@fake-control.example.net (" + if len(got) < len(wantPrefix) || got[:len(wantPrefix)] != wantPrefix { + t.Errorf("got %q, want prefix %q", got, wantPrefix) + } + t.Logf("response: %s", got) + + d1.MustCleanShutdown(t) + d2.MustCleanShutdown(t) +} diff --git a/tstest/iosdeps/iosdeps.go b/tstest/iosdeps/iosdeps.go index f6290af676e97..a1279e20baa15 100644 --- a/tstest/iosdeps/iosdeps.go +++ b/tstest/iosdeps/iosdeps.go @@ -4,28 +4,36 @@ // Package iosdeps is a just a list of the packages we import on iOS, to let us // test that our transitive closure of dependencies on iOS doesn't accidentally // grow too large, as we've historically been memory constrained there. +// +// It is intended to mirror the imports of the ipn-go-bridge package in the +// private "corp" repository (the Go side of the iOS / macOS app). package iosdeps import ( _ "bufio" _ "bytes" - _ "context" - _ "crypto/rand" + _ "crypto" + _ "crypto/ecdsa" + _ "crypto/elliptic" _ "crypto/sha256" + _ "encoding/base64" _ "encoding/json" _ "errors" _ "fmt" _ "io" - _ "io/fs" _ "log" _ "math" _ "net" _ "net/http" + _ "net/netip" + _ "net/url" _ "os" _ "os/signal" _ "path/filepath" _ "runtime" _ "runtime/debug" + _ "slices" + _ "strconv" _ "strings" _ "sync" _ "sync/atomic" @@ -35,24 +43,48 @@ import ( _ "github.com/tailscale/wireguard-go/device" _ "github.com/tailscale/wireguard-go/tun" - _ "go4.org/mem" _ "golang.org/x/sys/unix" + _ "tailscale.com/client/tailscale/apitype" + _ "tailscale.com/drive/driveimpl" + _ "tailscale.com/envknob" + _ "tailscale.com/feature/condregister" + _ "tailscale.com/feature/syspolicy" + _ "tailscale.com/feature/taildrop" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" + _ "tailscale.com/ipn/ipnauth" _ "tailscale.com/ipn/ipnlocal" _ "tailscale.com/ipn/localapi" + _ "tailscale.com/logpolicy" _ "tailscale.com/logtail" _ "tailscale.com/logtail/filch" _ "tailscale.com/net/dns" - _ "tailscale.com/net/netaddr" + _ "tailscale.com/net/netmon" + _ "tailscale.com/net/netutil" + _ "tailscale.com/net/tsaddr" _ "tailscale.com/net/tsdial" + _ "tailscale.com/net/tshttpproxy" _ "tailscale.com/net/tstun" _ "tailscale.com/paths" + _ "tailscale.com/safesocket" + _ "tailscale.com/tsd" _ "tailscale.com/types/empty" + _ "tailscale.com/types/key" + _ "tailscale.com/types/lazy" _ "tailscale.com/types/logger" + _ "tailscale.com/types/logid" + _ "tailscale.com/types/netmap" _ "tailscale.com/util/clientmetric" _ "tailscale.com/util/dnsname" + _ "tailscale.com/util/eventbus" + _ "tailscale.com/util/must" + _ "tailscale.com/util/set" + _ "tailscale.com/util/syspolicy" + _ "tailscale.com/util/syspolicy/pkey" + _ "tailscale.com/util/syspolicy/setting" + _ "tailscale.com/util/syspolicy/source" _ "tailscale.com/version" _ "tailscale.com/wgengine" + _ "tailscale.com/wgengine/netstack" _ "tailscale.com/wgengine/router" ) diff --git a/tstest/kernel_linux.go b/tstest/kernel_linux.go index ab7c0d529fc13..ed48fd071f251 100644 --- a/tstest/kernel_linux.go +++ b/tstest/kernel_linux.go @@ -20,8 +20,13 @@ func KernelVersion() (major, minor, patch int) { return 0, 0, 0 } release := unix.ByteSliceToString(uname.Release[:]) + return parseKernelVersion(release) +} - // Parse version string (e.g., "5.15.0-...") +// parseKernelVersion parses a Linux kernel version string like "6.12.73+deb13-amd64" +// or "5.15.0-76-generic" and returns the major, minor, and patch components. +// It returns (0, 0, 0) if the version cannot be parsed. +func parseKernelVersion(release string) (major, minor, patch int) { parts := strings.Split(release, ".") if len(parts) < 3 { return 0, 0, 0 @@ -37,9 +42,12 @@ func KernelVersion() (major, minor, patch int) { return 0, 0, 0 } - // Patch version may have additional info after a hyphen (e.g., "0-76-generic") - // Extract just the numeric part before any hyphen - patchStr, _, _ := strings.Cut(parts[2], "-") + // Patch version may have additional info after a hyphen or plus (e.g., "0-76-generic" or "41+deb13-amd64") + // Extract just the numeric part before any hyphen or plus + patchStr := parts[2] + if idx := strings.IndexAny(patchStr, "-+"); idx != -1 { + patchStr = patchStr[:idx] + } patch, err = strconv.Atoi(patchStr) if err != nil { diff --git a/tstest/kernel_linux_test.go b/tstest/kernel_linux_test.go new file mode 100644 index 0000000000000..9445ebe2c3866 --- /dev/null +++ b/tstest/kernel_linux_test.go @@ -0,0 +1,34 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package tstest + +import "testing" + +func TestParseKernelVersion(t *testing.T) { + tests := []struct { + release string + major, minor, patch int + }{ + {"5.15.0-76-generic", 5, 15, 0}, + {"6.12.73+deb13-amd64", 6, 12, 73}, + {"6.1.0-18-amd64", 6, 1, 0}, + {"5.4.0", 5, 4, 0}, + {"6.8.12", 6, 8, 12}, + {"4.19.0+1", 4, 19, 0}, + {"6.12.41+deb13-amd64", 6, 12, 41}, + {"", 0, 0, 0}, + {"not-a-version", 0, 0, 0}, + {"1.2", 0, 0, 0}, + {"a.b.c", 0, 0, 0}, + } + for _, tt := range tests { + major, minor, patch := parseKernelVersion(tt.release) + if major != tt.major || minor != tt.minor || patch != tt.patch { + t.Errorf("parseKernelVersion(%q) = (%d, %d, %d), want (%d, %d, %d)", + tt.release, major, minor, patch, tt.major, tt.minor, tt.patch) + } + } +} diff --git a/tstest/largetailnet/largetailnet.go b/tstest/largetailnet/largetailnet.go new file mode 100644 index 0000000000000..73ec2da805051 --- /dev/null +++ b/tstest/largetailnet/largetailnet.go @@ -0,0 +1,265 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package largetailnet provides reusable building blocks for in-process +// benchmarks and stress tests that drive a single tailnet client (typically a +// [tsnet.Server]) with a synthetic large-tailnet MapResponse stream. +// +// A [Streamer] takes over the map long-poll on a [testcontrol.Server] via the +// AltMapStream hook: it sends one initial MapResponse announcing the self +// node and N synthetic peers, and then forwards caller-supplied delta +// MapResponses on the same stream until ctx is done. +// +// The package is designed so that a benchmark can: +// +// - Build a [Streamer] with the desired peer count. +// - Stand up a [testcontrol.Server] with the streamer's [Streamer.AltMapStream] +// installed. +// - Stand up a [tsnet.Server] pointed at the testcontrol; its Up call +// blocks until the initial netmap has been processed. +// - Reset the benchmark timer and drive add/remove deltas with +// [Streamer.SendDelta] and [Streamer.AllocPeer]. +package largetailnet + +import ( + "context" + cryptorand "crypto/rand" + "fmt" + "net/netip" + "sync/atomic" + "time" + + "go4.org/mem" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/key" +) + +// SelfUserID is the synthetic [tailcfg.UserID] assigned to the self node and +// to every initial peer produced by [Streamer]. Tests that build their own +// peers via [MakePeer] should pass this value. +const SelfUserID tailcfg.UserID = 1_000_000 + +// Streamer drives a controlled MapResponse stream to a single client via +// [testcontrol.Server.AltMapStream]. It synthesizes an initial netmap with N +// peers and forwards caller-supplied delta MapResponses on the same stream. +// +// A Streamer is single-shot: it expects exactly one map long-poll over its +// lifetime and is not safe for re-use across multiple clients. +type Streamer struct { + n int + derpMap *tailcfg.DERPMap + + started chan struct{} // closed when the alt-map-stream callback first fires + initialDone chan struct{} // closed after initial MapResponse has been written + deltas chan *tailcfg.MapResponse + + // nextID is the next free node ID. It starts at N+2 (1 is the self + // node, 2..N+1 are the initial peers) and is bumped by AllocPeer. + nextID atomic.Int64 +} + +// New constructs a Streamer that will produce an initial netmap with n peers +// and a self node when its AltMapStream callback first fires. derpMap is +// included verbatim in the initial MapResponse. +func New(n int, derpMap *tailcfg.DERPMap) *Streamer { + s := &Streamer{ + n: n, + derpMap: derpMap, + started: make(chan struct{}), + initialDone: make(chan struct{}), + // Buffered so a benchmark loop body that does send-then-wait + // doesn't block on the channel under steady state. + deltas: make(chan *tailcfg.MapResponse, 64), + } + s.nextID.Store(int64(n) + 2) + return s +} + +// AltMapStream returns a callback suitable for [testcontrol.Server.AltMapStream]. +// On the first streaming long-poll it sends the initial big MapResponse and +// then forwards deltas enqueued via [Streamer.SendDelta] until ctx is done. +// Non-streaming "lite" polls are answered with an empty MapResponse so they +// complete quickly. The streamer is single-shot: any later streaming polls +// are kept alive but produce no further messages. +func (s *Streamer) AltMapStream() testcontrol.AltMapStreamFunc { + return func(ctx context.Context, w testcontrol.MapStreamWriter, req *tailcfg.MapRequest) { + if !req.Stream { + _ = w.SendMapMessage(&tailcfg.MapResponse{}) + return + } + + select { + case <-s.started: + // Re-poll after the original stream ended. Keep the + // connection alive so the client doesn't churn. + <-ctx.Done() + return + default: + close(s.started) + } + + if err := s.sendInitial(w, req); err != nil { + // Make the failure loud rather than wedging the + // caller's [tsnet.Server.Up] on a silent retry loop. + panic(fmt.Sprintf("largetailnet: sendInitial: %v", err)) + } + close(s.initialDone) + + for { + select { + case <-ctx.Done(): + return + case mr := <-s.deltas: + if err := w.SendMapMessage(mr); err != nil { + <-ctx.Done() + return + } + } + } + } +} + +// AwaitInitialSent blocks until the initial big MapResponse has been written +// to the wire. Note this is not the same as "the client has finished +// processing it"; for that, callers should rely on [tsnet.Server.Up] +// returning, or watch the IPN bus. +func (s *Streamer) AwaitInitialSent(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-s.initialDone: + return nil + } +} + +// SendDelta enqueues mr for delivery on the active MapResponse stream. It +// blocks if the internal queue is full or the stream hasn't started yet. +func (s *Streamer) SendDelta(ctx context.Context, mr *tailcfg.MapResponse) error { + select { + case <-ctx.Done(): + return ctx.Err() + case s.deltas <- mr: + return nil + } +} + +// AllocPeer returns a fresh synthetic peer node with a never-before-used +// [tailcfg.NodeID]. It's intended for use in PeersChanged deltas. +func (s *Streamer) AllocPeer() *tailcfg.Node { + return MakePeer(tailcfg.NodeID(s.nextID.Add(1)-1), SelfUserID) +} + +// SelfNodeID returns the [tailcfg.NodeID] used for the self node in the +// initial netmap. +func (s *Streamer) SelfNodeID() tailcfg.NodeID { return 1 } + +// sendInitial writes the big initial MapResponse with s.n peers. +func (s *Streamer) sendInitial(w testcontrol.MapStreamWriter, req *tailcfg.MapRequest) error { + selfNodeID := s.SelfNodeID() + selfIP4 := node4(selfNodeID) + selfIP6 := node6(selfNodeID) + + peers := make([]*tailcfg.Node, 0, s.n) + for i := 0; i < s.n; i++ { + peers = append(peers, MakePeer(tailcfg.NodeID(i+2), SelfUserID)) + } + + now := time.Now().UTC() + selfNode := &tailcfg.Node{ + ID: selfNodeID, + StableID: "largetailnet-self", + Name: "self.largetailnet.ts.net.", + User: SelfUserID, + Key: req.NodeKey, + KeyExpiry: now.Add(24 * time.Hour), + Machine: randMachineKey(), // fake; client doesn't verify + DiscoKey: req.DiscoKey, + MachineAuthorized: true, + Addresses: []netip.Prefix{selfIP4, selfIP6}, + AllowedIPs: []netip.Prefix{selfIP4, selfIP6}, + CapMap: map[tailcfg.NodeCapability][]tailcfg.RawMessage{}, + } + + initial := &tailcfg.MapResponse{ + KeepAlive: false, + Node: selfNode, + DERPMap: s.derpMap, + Peers: peers, + PacketFilter: []tailcfg.FilterRule{{ + // Accept-all filter so the client isn't logging packet-filter + // failures; this is a benchmark harness, not a security test. + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{{IP: "*", Ports: tailcfg.PortRangeAny}}, + }}, + DNSConfig: &tailcfg.DNSConfig{}, + Domain: "largetailnet.ts.net", + UserProfiles: []tailcfg.UserProfile{{ + ID: SelfUserID, + LoginName: "largetailnet@example.com", + DisplayName: "largetailnet", + }}, + ControlTime: &now, + } + return w.SendMapMessage(initial) +} + +// MakePeer constructs a synthetic [tailcfg.Node] for the given NodeID and +// UserID. The peer's node/disco/machine keys are derived from random bytes +// via the *PublicFromRaw32 constructors rather than via key.New*().Public(), +// which avoids the per-peer Curve25519 ScalarBaseMult and lets the harness +// construct hundreds of thousands of peers in a few hundred milliseconds. +// The client never crypto-validates these keys in the bench, so opaque +// random bytes are sufficient. +func MakePeer(nid tailcfg.NodeID, user tailcfg.UserID) *tailcfg.Node { + v4, v6 := node4(nid), node6(nid) + name := fmt.Sprintf("peer-%d", nid) + return &tailcfg.Node{ + ID: nid, + StableID: tailcfg.StableNodeID(name), + Name: name + ".largetailnet.ts.net.", + Key: randNodeKey(), + MachineAuthorized: true, + DiscoKey: randDiscoKey(), + Machine: randMachineKey(), + Addresses: []netip.Prefix{v4, v6}, + AllowedIPs: []netip.Prefix{v4, v6}, + User: user, + // Hostinfo must be non-nil: LocalBackend.populatePeerStatus + // dereferences it via HostinfoView.Hostname unconditionally. + Hostinfo: (&tailcfg.Hostinfo{Hostname: name}).View(), + } +} + +func randNodeKey() key.NodePublic { + var b [32]byte + cryptorand.Read(b[:]) + return key.NodePublicFromRaw32(mem.B(b[:])) +} + +func randDiscoKey() key.DiscoPublic { + var b [32]byte + cryptorand.Read(b[:]) + return key.DiscoPublicFromRaw32(mem.B(b[:])) +} + +func randMachineKey() key.MachinePublic { + var b [32]byte + cryptorand.Read(b[:]) + return key.MachinePublicFromRaw32(mem.B(b[:])) +} + +func node4(nid tailcfg.NodeID) netip.Prefix { + return netip.PrefixFrom( + netip.AddrFrom4([4]byte{100, 100 + byte(nid>>16), byte(nid >> 8), byte(nid)}), + 32) +} + +func node6(nid tailcfg.NodeID) netip.Prefix { + a := tsaddr.TailscaleULARange().Addr().As16() + a[13] = byte(nid >> 16) + a[14] = byte(nid >> 8) + a[15] = byte(nid) + return netip.PrefixFrom(netip.AddrFrom16(a), 128) +} diff --git a/tstest/largetailnet/largetailnet_test.go b/tstest/largetailnet/largetailnet_test.go new file mode 100644 index 0000000000000..07f67df820014 --- /dev/null +++ b/tstest/largetailnet/largetailnet_test.go @@ -0,0 +1,218 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package largetailnet_test + +import ( + "context" + "flag" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/tstest/integration" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/tstest/largetailnet" + "tailscale.com/types/logger" +) + +// tsnet.Server.Up handles the wait-for-ipn.Running step itself: it +// subscribes to the IPN bus with NotifyInitialState and blocks until State +// reaches ipn.Running, which by definition means a netmap has been applied. +// We don't redo that work here. + +var ( + flagActuallyTest = flag.Bool("actually-test-giant-tailnet", false, + "if set, run the BenchmarkGiantTailnet* benchmarks; otherwise they are skipped") + flagN = flag.Int("giant-tailnet-n", 250_000, + "size of the initial netmap (peer count) for BenchmarkGiantTailnet*") + flagBenchVerbose = flag.Bool("giant-tailnet-verbose", false, + "if set, log tsnet output and DERP setup to stderr") +) + +// BenchmarkGiantTailnet measures the per-delta CPU cost of a tailnet client +// processing peer-add/peer-remove deltas in steady state, with no IPN bus +// subscribers attached. This represents the headless-tailscaled workload +// (Linux subnet routers, container sidecars, ...) where the LocalBackend +// does not pay for fanning Notify.NetMap out to GUI watchers. +// +// Use [BenchmarkGiantTailnetBusWatcher] for the GUI-client workload. +// +// The benchmark is opt-in via --actually-test-giant-tailnet. +func BenchmarkGiantTailnet(b *testing.B) { + if !*flagActuallyTest { + b.Skip("set --actually-test-giant-tailnet to run this benchmark") + } + benchGiantTailnet(b, false) +} + +// BenchmarkGiantTailnetBusWatcher is like [BenchmarkGiantTailnet] but +// attaches one [local.Client.WatchIPNBus] subscriber for the duration of the +// benchmark. The Notify-fan-out cost (notably Notify.NetMap encoding to +// every watcher on every full-rebuild path) is therefore included in the +// per-delta measurement, which approximates the GUI-client workload. +// +// The benchmark is opt-in via --actually-test-giant-tailnet. +func BenchmarkGiantTailnetBusWatcher(b *testing.B) { + if !*flagActuallyTest { + b.Skip("set --actually-test-giant-tailnet to run this benchmark") + } + benchGiantTailnet(b, true) +} + +// benchGiantTailnet is the shared body of the BenchmarkGiantTailnet* +// benchmarks. Setup is entirely in-process: a [testcontrol.Server] hosts +// the control plane, a [tsnet.Server] hosts the client, and a +// [largetailnet.Streamer] hijacks the map long-poll to drive an exact +// MapResponse sequence. +// +// Each loop iteration sends one [tailcfg.MapResponse] with PeersChanged +// (a fresh peer) and PeersRemoved (the previous fresh peer), then waits +// for the client to apply it. Net peer count stays at flagN throughout the +// loop. +// +// The wait mechanism differs by variant: +// +// - busWatcher=false: block on a channel returned by +// [ipnlocal.LocalBackend.AwaitNodeKeyForTest] (reached via +// [tsnet.TestHooks]). The channel is closed by LocalBackend the moment +// the just-added peer's key appears in the netmap, so the wait has zero +// polling overhead. +// - busWatcher=true: drain Notify events from the bus subscription, since +// a Notify firing is exactly the side-effect we want to amortize into +// the per-delta measurement. +// +// Recommended invocation for profiling on unmodified main: +// +// go test ./tstest/largetailnet/ -run=^$ \ +// -bench='BenchmarkGiantTailnet(BusWatcher)?$' \ +// -benchtime=2000x -timeout=10m \ +// --actually-test-giant-tailnet \ +// --giant-tailnet-n=250000 \ +// -cpuprofile=/tmp/giant.cpu.pprof +func benchGiantTailnet(b *testing.B, busWatcher bool) { + logf := logger.Discard + if *flagBenchVerbose { + logf = b.Logf + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + b.Cleanup(cancel) + + derpMap := integration.RunDERPAndSTUN(b, logf, "127.0.0.1") + + streamer := largetailnet.New(*flagN, derpMap) + + ctrl := &testcontrol.Server{ + DERPMap: derpMap, + DNSConfig: &tailcfg.DNSConfig{}, + AltMapStream: streamer.AltMapStream(), + Logf: logf, + } + ctrl.HTTPTestServer = httptest.NewUnstartedServer(ctrl) + ctrl.HTTPTestServer.Start() + b.Cleanup(ctrl.HTTPTestServer.Close) + controlURL := ctrl.HTTPTestServer.URL + b.Logf("testcontrol listening on %s", controlURL) + + tmp := filepath.Join(b.TempDir(), "tsnet") + if err := os.MkdirAll(tmp, 0755); err != nil { + b.Fatal(err) + } + + s := &tsnet.Server{ + Dir: tmp, + ControlURL: controlURL, + Hostname: "largetailnet-bench", + Store: new(mem.Store), + Ephemeral: true, + Logf: logf, + } + b.Cleanup(func() { s.Close() }) + + // tsnet.Server.Up blocks until the backend reaches Running, which + // requires the initial flagN-peer MapResponse to have been processed. + upStart := time.Now() + if _, err := s.Up(ctx); err != nil { + b.Fatalf("tsnet.Server.Up: %v", err) + } + b.Logf("initial %d-peer netmap processed in %v", *flagN, time.Since(upStart)) + + lc, err := s.LocalClient() + if err != nil { + b.Fatalf("LocalClient: %v", err) + } + lb := tsnet.TestHooks.LocalBackend(s) + + var notifyCh chan struct{} + if busWatcher { + bw, err := lc.WatchIPNBus(ctx, 0) + if err != nil { + b.Fatalf("WatchIPNBus: %v", err) + } + b.Cleanup(func() { bw.Close() }) + notifyCh = make(chan struct{}, 1024) + go func() { + for { + n, err := bw.Next() + if err != nil { + return + } + if n.NetMap != nil || len(n.PeerChanges) > 0 { + select { + case notifyCh <- struct{}{}: + default: + } + } + } + }() + } + + var prevAdded *tailcfg.Node + runtime.GC() + + b.ResetTimer() + for b.Loop() { + added := streamer.AllocPeer() + mr := &tailcfg.MapResponse{ + PeersChanged: []*tailcfg.Node{added}, + } + if prevAdded != nil { + mr.PeersRemoved = []tailcfg.NodeID{prevAdded.ID} + } + prevAdded = added + + if err := streamer.SendDelta(ctx, mr); err != nil { + b.Fatalf("SendDelta: %v", err) + } + + if busWatcher { + // A Notify firing is itself part of the workload we + // want to measure on this variant. + select { + case <-notifyCh: + case <-time.After(10 * time.Second): + b.Fatal("timed out waiting for notify") + case <-ctx.Done(): + b.Fatalf("ctx done waiting for notify: %v", ctx.Err()) + } + } else { + // Block on the LocalBackend's test-only signal that + // the just-added peer key has landed in the netmap. + // No polling, no notify fan-out cost. + select { + case <-lb.AwaitNodeKeyForTest(added.Key): + case <-time.After(10 * time.Second): + b.Fatalf("timed out waiting for node key %v", added.Key) + case <-ctx.Done(): + b.Fatalf("ctx done waiting for node key: %v", ctx.Err()) + } + } + } +} diff --git a/tstest/natlab/natlab.go b/tstest/natlab/natlab.go index add812d8fe6e3..b66779eebe7a3 100644 --- a/tstest/natlab/natlab.go +++ b/tstest/natlab/natlab.go @@ -18,6 +18,7 @@ import ( "net" "net/netip" "os" + "slices" "sort" "strconv" "sync" @@ -247,12 +248,7 @@ func (f *Interface) String() string { // Contains reports whether f contains ip as an IP. func (f *Interface) Contains(ip netip.Addr) bool { - for _, v := range f.ips { - if ip == v { - return true - } - } - return false + return slices.Contains(f.ips, ip) } type routeEntry struct { @@ -348,10 +344,8 @@ func (m *Machine) isLocalIP(ip netip.Addr) bool { m.mu.Lock() defer m.mu.Unlock() for _, intf := range m.interfaces { - for _, iip := range intf.ips { - if ip == iip { - return true - } + if slices.Contains(intf.ips, ip) { + return true } } return false @@ -565,7 +559,7 @@ func (m *Machine) interfaceForIP(ip netip.Addr) (*Interface, error) { func (m *Machine) pickEphemPort() (port uint16, err error) { m.mu.Lock() defer m.mu.Unlock() - for tries := 0; tries < 500; tries++ { + for range 500 { port := uint16(rand.IntN(32<<10) + 32<<10) if !m.portInUseLocked(port) { return port, nil diff --git a/tstest/natlab/vmtest/assets/event.html b/tstest/natlab/vmtest/assets/event.html new file mode 100644 index 0000000000000..a5f5966730fb5 --- /dev/null +++ b/tstest/natlab/vmtest/assets/event.html @@ -0,0 +1,45 @@ +{{if eq .Type "test_status"}} +{{.Message}} ({{.Detail}}) +{{end}} + +{{if eq .Type "step_changed"}} +
+ {{.Step.Status.Icon}} + {{.Step.Name}} + {{formatDuration .Step.Elapsed}} +
+{{end}} + +{{if eq .Type "console_output"}} +
{{ansi .Message}} +
+{{end}} + +{{if eq .Type "dhcp_discover"}} +Discover sent +{{end}} + +{{if eq .Type "dhcp_offer"}} +Offered {{.Detail}} +{{end}} + +{{if eq .Type "dhcp_request"}} +Requesting {{.Detail}} +{{end}} + +{{if eq .Type "dhcp_ack"}} +Got {{.Detail}} +{{end}} + +{{if eq .Type "tailscale"}} +{{.Detail}} +{{end}} + +{{if eq .Type "screenshot"}} +
+{{end}} + +{{if ne .Type "screenshot"}} +
{{.Time.Format "15:04:05.000"}} {{if .NodeName}}[{{.NodeName}}] {{end}}{{.Message}}{{if .Detail}} {{.Detail}}{{end}}
+
+{{end}} diff --git a/tstest/natlab/vmtest/assets/index.html b/tstest/natlab/vmtest/assets/index.html new file mode 100644 index 0000000000000..044efffeef354 --- /dev/null +++ b/tstest/natlab/vmtest/assets/index.html @@ -0,0 +1,112 @@ + + + + + VMTest: {{.TestName}} + + + + + + +

VMTest: {{.TestName}} {{.TestStatus.State}} ({{formatDuration .TestStatus.Elapsed}})

+ +
+

Progress

+ {{range .Steps}} +
+ {{.Status.Icon}} + {{.Name}} + {{if ne .Status.String "pending"}}{{formatDuration .Elapsed}}{{end}} +
+ {{end}} +
+ +
+ {{range $node := .Nodes}} +
+
+ {{$node.Name}} + {{$node.OS}} +
+
+ {{range $i, $nic := $node.NICs}} +
+ DHCP{{if gt (len $node.NICs) 1}} ({{$nic.NetName}}){{end}}: + {{$nic.DHCP}} +
+ {{end}} + {{if $node.JoinsTailnet}} +
+ Tailscale: + {{$node.Tailscale}} +
+ {{end}} +
+
{{if $node.Screenshot}}{{end}}
+
{{range $node.Console}}{{ansi .}} +{{end}}
+
+ {{end}} +
+ +
+

Events

+
+
+ + + + + diff --git a/tstest/natlab/vmtest/assets/style.css b/tstest/natlab/vmtest/assets/style.css new file mode 100644 index 0000000000000..5970598b8afdd --- /dev/null +++ b/tstest/natlab/vmtest/assets/style.css @@ -0,0 +1,182 @@ +/* CSS reset */ +*, *::before, *::after { box-sizing: border-box; } +* { margin: 0; } +body { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; + line-height: 1.5; + background: #1a1a2e; + color: #e0e0e0; + padding: 16px; +} + +h1 { + font-size: 1.4em; + margin-bottom: 16px; + color: #fff; +} + +.test-status { + font-size: 0.7em; + padding: 2px 10px; + border-radius: 4px; + font-weight: bold; + vertical-align: middle; +} + +.test-Running { background: #2563eb; color: #fff; } +.test-Passed { background: #16a34a; color: #fff; } +.test-Failed { background: #dc2626; color: #fff; } + +h2 { + font-size: 1.1em; + margin-bottom: 8px; + color: #ccc; +} + +/* Step progress panel */ +.steps { + background: #16213e; + border: 1px solid #333; + border-radius: 6px; + padding: 12px; + margin-bottom: 16px; +} + +.step { + display: flex; + align-items: center; + gap: 8px; + padding: 4px 8px; + font-family: monospace; + font-size: 13px; + border-radius: 3px; +} + +.step-pending { color: #666; } +.step-running { color: #4af; font-weight: bold; background: rgba(68, 170, 255, 0.1); } +.step-done { color: #4a4; } +.step-failed { color: #f44; font-weight: bold; background: rgba(255, 68, 68, 0.1); } + +.step-icon { width: 1.2em; text-align: center; } +.step-name { flex: 1; } +.step-time { color: #666; font-size: 12px; min-width: 6em; text-align: right; } + +/* VM card grid */ +.vm-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(400px, 1fr)); + gap: 12px; + margin-bottom: 16px; +} + +.vm-card { + background: #16213e; + border: 1px solid #333; + border-radius: 6px; + padding: 12px; +} + +.vm-header { + display: flex; + align-items: center; + gap: 8px; + margin-bottom: 8px; +} + +.vm-name { + font-weight: bold; + font-size: 1.1em; + color: #fff; +} + +.vm-os { + font-size: 0.8em; + background: #333; + padding: 1px 6px; + border-radius: 3px; + color: #aaa; +} + +.vm-status { + display: flex; + flex-direction: column; + gap: 2px; + margin-bottom: 8px; + font-family: monospace; + font-size: 13px; +} + +.vm-status-line { + display: flex; + gap: 8px; +} + +.vm-status-label { + color: #888; + min-width: 7em; +} + +.vm-status-value { + color: #4af; +} + +/* VM display screenshot */ +.screenshot:empty { display: none; } +.screenshot { + margin-bottom: 4px; +} +.screenshot img { + width: 100%; + height: auto; + display: block; + border-radius: 4px; + border: 1px solid #222; + cursor: pointer; +} + +/* Console output */ +.console { + background: #0a0a0a; + color: #ccc; + font-family: "Cascadia Code", "Fira Code", "Consolas", monospace; + font-size: 11px; + line-height: 1.3; + max-height: 300px; + overflow-y: auto; + white-space: pre-wrap; + word-break: break-all; + padding: 8px; + border-radius: 4px; + border: 1px solid #222; +} + +/* Event log */ +.event-log { + background: #16213e; + border: 1px solid #333; + border-radius: 6px; + padding: 12px; +} + +.events { + max-height: 300px; + overflow-y: auto; +} + +.event { + font-family: monospace; + font-size: 12px; + padding: 1px 0; + border-bottom: 1px solid #1a1a2e; +} + +.event-time { color: #666; } +.event-node { color: #4af; font-weight: bold; } +.event-msg { color: #ccc; } +.event-detail { color: #888; } + +.event-dhcp_discover .event-msg, +.event-dhcp_request .event-msg { color: #fa4; } +.event-dhcp_offer .event-msg, +.event-dhcp_ack .event-msg { color: #4f4; } +.event-step_changed .event-msg { color: #aaf; } diff --git a/tstest/natlab/vmtest/cloudinit.go b/tstest/natlab/vmtest/cloudinit.go new file mode 100644 index 0000000000000..a00f849ba81d7 --- /dev/null +++ b/tstest/natlab/vmtest/cloudinit.go @@ -0,0 +1,188 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/kdomanski/iso9660" +) + +// createCloudInitISO creates a cidata seed ISO for the given cloud VM node. +// For Linux VMs, the ISO contains meta-data, user-data, and network-config. +// For FreeBSD VMs, the ISO contains meta-data and user-data only (nuageinit +// doesn't use netplan-style network-config; DHCP is enabled in rc.conf). +func (e *Env) createCloudInitISO(n *Node) (string, error) { + metaData := fmt.Sprintf("instance-id: %s\nlocal-hostname: %s\n", n.name, n.name) + userData := e.generateUserData(n) + + files := map[string]string{ + "meta-data": metaData, + "user-data": userData, + } + + // Linux cloud-init needs network-config to configure interfaces before + // systemd-networkd-wait-online blocks boot. + if n.os.GOOS() == "linux" { + files["network-config"] = `version: 2 +ethernets: + primary: + match: + macaddress: "` + n.vnetNode.NICMac(0).String() + `" + dhcp4: true + dhcp4-overrides: + route-metric: 100 + optional: true + secondary: + match: + name: "en*" + dhcp4: true + dhcp4-overrides: + route-metric: 200 + optional: true +` + } + + iw, err := iso9660.NewWriter() + if err != nil { + return "", fmt.Errorf("creating ISO writer: %w", err) + } + defer iw.Cleanup() + + for name, content := range files { + if err := iw.AddFile(strings.NewReader(content), name); err != nil { + return "", fmt.Errorf("adding %s to ISO: %w", name, err) + } + } + + isoPath := filepath.Join(e.tempDir, n.name+"-seed.iso") + f, err := os.Create(isoPath) + if err != nil { + return "", err + } + defer f.Close() + if err := iw.WriteTo(f, "cidata"); err != nil { + return "", fmt.Errorf("writing seed ISO: %w", err) + } + return isoPath, nil +} + +// generateUserData creates the cloud-init user-data (#cloud-config) for a node. +func (e *Env) generateUserData(n *Node) string { + switch n.os.GOOS() { + case "linux": + return e.generateLinuxUserData(n) + case "freebsd": + return e.generateFreeBSDUserData(n) + default: + panic(fmt.Sprintf("unsupported GOOS %q for cloud-init user-data", n.os.GOOS())) + } +} + +// generateLinuxUserData creates Linux cloud-init user-data (#cloud-config) for a node. +func (e *Env) generateLinuxUserData(n *Node) string { + var ud strings.Builder + ud.WriteString("#cloud-config\n") + + // Enable root SSH login for debugging via the debug NIC. + ud.WriteString("ssh_pwauth: true\n") + ud.WriteString("disable_root: false\n") + ud.WriteString("users:\n") + ud.WriteString(" - name: root\n") + ud.WriteString(" lock_passwd: false\n") + ud.WriteString(" plain_text_passwd: root\n") + // Also inject the host's SSH key if available. + if pubkey, err := os.ReadFile("/tmp/vmtest_key.pub"); err == nil { + ud.WriteString(fmt.Sprintf(" ssh_authorized_keys:\n - %s\n", strings.TrimSpace(string(pubkey)))) + } + + ud.WriteString("runcmd:\n") + + // Remove the default route from the debug NIC (enp0s4) so traffic goes through vnet. + // The debug NIC is only for SSH access from the host. + ud.WriteString(" - [\"/bin/sh\", \"-c\", \"ip route del default via 10.0.2.2 dev enp0s4 2>/dev/null || true\"]\n") + + // Download binaries from the files.tailscale VIP (52.52.0.6). + // Use the IP directly to avoid DNS resolution issues during early boot. + binDir := n.os.GOOS() + "_" + n.os.GOARCH() + for _, bin := range []string{"tailscaled", "tailscale", "tta"} { + fmt.Fprintf(&ud, " - [\"/bin/sh\", \"-c\", \"curl -v --retry 10 --retry-delay 2 --retry-all-errors -o /usr/local/bin/%s http://52.52.0.6/%s/%s 2>&1\"]\n", bin, binDir, bin) + } + ud.WriteString(" - [\"chmod\", \"+x\", \"/usr/local/bin/tailscaled\", \"/usr/local/bin/tailscale\", \"/usr/local/bin/tta\"]\n") + + // Enable IP forwarding for subnet routers. + if n.advertiseRoutes != "" { + ud.WriteString(" - [\"sysctl\", \"-w\", \"net.ipv4.ip_forward=1\"]\n") + ud.WriteString(" - [\"sysctl\", \"-w\", \"net.ipv6.conf.all.forwarding=1\"]\n") + } + + // Start tailscaled in the background. --statedir provides a VarRoot so + // features like Taildrop (which needs a place to stash incoming files) + // have a directory to work with. + ud.WriteString(" - [\"mkdir\", \"-p\", \"/var/lib/tailscale\"]\n") + ud.WriteString(" - [\"/bin/sh\", \"-c\", \"/usr/local/bin/tailscaled --state=mem: --statedir=/var/lib/tailscale &\"]\n") + ud.WriteString(" - [\"sleep\", \"2\"]\n") + + // Start tta (Tailscale Test Agent). + ud.WriteString(" - [\"/bin/sh\", \"-c\", \"/usr/local/bin/tta &\"]\n") + + return ud.String() +} + +// generateFreeBSDUserData creates FreeBSD nuageinit user-data (#cloud-config) +// for a node. FreeBSD's nuageinit supports a subset of cloud-init directives +// including runcmd, which runs after networking is up. +// +// IMPORTANT: nuageinit's runcmd only supports string entries, not the YAML +// array form that Linux cloud-init supports. Each entry must be a plain string +// that gets passed to /bin/sh -c. +func (e *Env) generateFreeBSDUserData(n *Node) string { + var ud strings.Builder + ud.WriteString("#cloud-config\n") + ud.WriteString("ssh_pwauth: true\n") + + ud.WriteString("runcmd:\n") + + // /usr/local/bin may not exist on a fresh FreeBSD cloud image (it's + // created when the first package is installed). + ud.WriteString(" - \"mkdir -p /usr/local/bin\"\n") + + // Remove the default route via the debug NIC's SLIRP gateway so that + // traffic goes through the vnet NICs. The debug NIC is only for SSH. + ud.WriteString(" - \"route delete default 10.0.2.2 2>/dev/null || true\"\n") + + // Download binaries from the files.tailscale VIP (52.52.0.6). + // FreeBSD's fetch(1) is part of the base system (no curl needed). + // Retry in a loop since the file server may not be ready immediately. + binDir := n.os.GOOS() + "_" + n.os.GOARCH() + for _, bin := range []string{"tailscaled", "tailscale", "tta"} { + fmt.Fprintf(&ud, " - \"n=0; while [ $n -lt 10 ]; do fetch -o /usr/local/bin/%s http://52.52.0.6/%s/%s && break; n=$((n+1)); sleep 2; done\"\n", bin, binDir, bin) + } + ud.WriteString(" - \"chmod +x /usr/local/bin/tailscaled /usr/local/bin/tailscale /usr/local/bin/tta\"\n") + + // Enable IP forwarding for subnet routers. + // This is currently a noop as of 2026-04-08 because FreeBSD uses + // gvisor netstack for subnet routing until + // https://github.com/tailscale/tailscale/issues/5573 etc are fixed. + if n.advertiseRoutes != "" { + ud.WriteString(" - \"sysctl net.inet.ip.forwarding=1\"\n") + ud.WriteString(" - \"sysctl net.inet6.ip6.forwarding=1\"\n") + } + + // Start tailscaled and tta in the background. + // Set PATH to include /usr/local/bin so that tta can find "tailscale" + // (TTA uses exec.Command("tailscale", ...) without a full path). + // --statedir provides a VarRoot so features like Taildrop have a directory. + ud.WriteString(" - \"mkdir -p /var/lib/tailscale\"\n") + ud.WriteString(" - \"export PATH=/usr/local/bin:$PATH && /usr/local/bin/tailscaled --state=mem: --statedir=/var/lib/tailscale &\"\n") + ud.WriteString(" - \"sleep 2\"\n") + + // Start tta (Tailscale Test Agent). + ud.WriteString(" - \"export PATH=/usr/local/bin:$PATH && /usr/local/bin/tta &\"\n") + + return ud.String() +} diff --git a/tstest/natlab/vmtest/images.go b/tstest/natlab/vmtest/images.go new file mode 100644 index 0000000000000..bce5452a4d0b8 --- /dev/null +++ b/tstest/natlab/vmtest/images.go @@ -0,0 +1,223 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/ulikunitz/xz" +) + +// OSImage describes a VM operating system image. +type OSImage struct { + Name string + URL string // download URL for the cloud image + SHA256 string // expected SHA256 hash of the image (of the final qcow2, after any decompression) + MemoryMB int // RAM for the VM + IsGokrazy bool // true for gokrazy images (different QEMU setup) + IsMacOS bool // true for macOS images (launched via tailmac, not QEMU) +} + +// GOOS returns the Go OS name for this image. +func (img OSImage) GOOS() string { + if img.IsMacOS { + return "darwin" + } + if img.IsGokrazy { + return "linux" + } + if strings.HasPrefix(img.Name, "freebsd") { + return "freebsd" + } + return "linux" +} + +// GOARCH returns the Go architecture name for this image. +func (img OSImage) GOARCH() string { + if img.IsMacOS { + return "arm64" + } + return "amd64" +} + +var ( + // Gokrazy is a minimal Tailscale appliance image built from the gokrazy/natlabapp directory. + Gokrazy = OSImage{ + Name: "gokrazy", + IsGokrazy: true, + MemoryMB: 384, + } + + // Ubuntu2404 is Ubuntu 24.04 LTS (Noble Numbat) cloud image. + Ubuntu2404 = OSImage{ + Name: "ubuntu-24.04", + URL: "https://cloud-images.ubuntu.com/noble/current/noble-server-cloudimg-amd64.img", + MemoryMB: 1024, + } + + // Debian12 is Debian 12 (Bookworm) generic cloud image. + Debian12 = OSImage{ + Name: "debian-12", + URL: "https://cloud.debian.org/images/cloud/bookworm/latest/debian-12-generic-amd64.qcow2", + MemoryMB: 1024, + } + + // FreeBSD150 is FreeBSD 15.0-RELEASE with BASIC-CLOUDINIT (nuageinit) support. + // The image is distributed as xz-compressed qcow2. + FreeBSD150 = OSImage{ + Name: "freebsd-15.0", + URL: "https://download.freebsd.org/releases/VM-IMAGES/15.0-RELEASE/amd64/Latest/FreeBSD-15.0-RELEASE-amd64-BASIC-CLOUDINIT-ufs.qcow2.xz", + MemoryMB: 1024, + } + + // MacOS is a macOS VM launched via tailmac (Apple Virtualization.framework). + // Uses a Tart pre-built base image (ghcr.io/cirruslabs/macos-tahoe-base) + // which is automatically pulled on first use. Only runs on macOS arm64 hosts. + MacOS = OSImage{ + Name: "macos", + IsMacOS: true, + MemoryMB: 4096, + } +) + +// imageCacheDir returns the directory for cached VM images. +func imageCacheDir() string { + if d := os.Getenv("VMTEST_CACHE_DIR"); d != "" { + return d + } + home, _ := os.UserHomeDir() + return filepath.Join(home, ".cache", "tailscale", "vmtest", "images") +} + +// ensureImage downloads and caches the OS image if not already present. +func ensureImage(ctx context.Context, img OSImage) error { + if img.IsGokrazy { + return nil // gokrazy images are handled separately + } + + cacheDir := imageCacheDir() + if err := os.MkdirAll(cacheDir, 0755); err != nil { + return err + } + + // Use a filename based on the image name. + cachedPath := filepath.Join(cacheDir, img.Name+".qcow2") + if _, err := os.Stat(cachedPath); err == nil { + // If we have a SHA256 to verify, check it. + if img.SHA256 != "" { + if err := verifySHA256(cachedPath, img.SHA256); err != nil { + log.Printf("cached image %s failed SHA256 check, re-downloading: %v", img.Name, err) + os.Remove(cachedPath) + } else { + return nil + } + } else { + return nil // exists, no hash to verify + } + } + + isXZ := strings.HasSuffix(img.URL, ".xz") + log.Printf("downloading %s from %s...", img.Name, img.URL) + + req, err := http.NewRequestWithContext(ctx, "GET", img.URL, nil) + if err != nil { + return fmt.Errorf("downloading %s: %w", img.Name, err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("downloading %s: %w", img.Name, err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return fmt.Errorf("downloading %s: HTTP %s", img.Name, resp.Status) + } + + // Set up the reader pipeline: HTTP body → (optional xz decompress) → file. + var src io.Reader = resp.Body + if isXZ { + xzr, err := xz.NewReader(resp.Body) + if err != nil { + return fmt.Errorf("creating xz reader for %s: %w", img.Name, err) + } + src = xzr + } + + tmpFile := cachedPath + ".tmp" + f, err := os.Create(tmpFile) + if err != nil { + return err + } + defer func() { + f.Close() + os.Remove(tmpFile) + }() + + h := sha256.New() + w := io.MultiWriter(f, h) + if _, err := io.Copy(w, src); err != nil { + return fmt.Errorf("downloading %s: %w", img.Name, err) + } + if err := f.Close(); err != nil { + return err + } + + if img.SHA256 != "" { + got := hex.EncodeToString(h.Sum(nil)) + if got != img.SHA256 { + return fmt.Errorf("SHA256 mismatch for %s: got %s, want %s", img.Name, got, img.SHA256) + } + } + + if err := os.Rename(tmpFile, cachedPath); err != nil { + return err + } + log.Printf("downloaded %s", img.Name) + return nil +} + +// verifySHA256 checks that the file at path has the expected SHA256 hash. +func verifySHA256(path, expected string) error { + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return err + } + got := hex.EncodeToString(h.Sum(nil)) + if got != expected { + return fmt.Errorf("got %s, want %s", got, expected) + } + return nil +} + +// cachedImagePath returns the filesystem path to the cached image for the given OS. +func cachedImagePath(img OSImage) string { + return filepath.Join(imageCacheDir(), img.Name+".qcow2") +} + +// createOverlay creates a qcow2 overlay image on top of the given base image. +func createOverlay(base, overlay string) error { + out, err := exec.Command("qemu-img", "create", + "-f", "qcow2", + "-F", "qcow2", + "-b", base, + overlay).CombinedOutput() + if err != nil { + return fmt.Errorf("qemu-img create overlay: %v: %s", err, out) + } + return nil +} diff --git a/tstest/natlab/vmtest/qemu.go b/tstest/natlab/vmtest/qemu.go new file mode 100644 index 0000000000000..757657e51f50f --- /dev/null +++ b/tstest/natlab/vmtest/qemu.go @@ -0,0 +1,365 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net" + "os" + "os/exec" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "tailscale.com/tstest/natlab/vnet" +) + +// gokrazyPlatform boots gokrazy (Linux) VMs via QEMU. +type gokrazyPlatform struct{} + +func (gokrazyPlatform) planSteps(e *Env, n *Node) { + e.Step("Build gokrazy image") + e.Step("Launch QEMU: " + n.name) +} + +func (gokrazyPlatform) boot(ctx context.Context, e *Env, n *Node) error { + e.gokrazyOnce.Do(func() { + step := e.Step("Build gokrazy image") + step.Begin() + if err := e.ensureGokrazy(ctx); err != nil { + step.End(err) + e.t.Fatalf("ensureGokrazy: %v", err) + } + step.End(nil) + }) + + e.ensureQEMUSocket() + + vmStep := e.Step("Launch QEMU: " + n.name) + vmStep.Begin() + if err := e.startGokrazyQEMU(n); err != nil { + vmStep.End(err) + return err + } + vmStep.End(nil) + return nil +} + +// qemuCloudPlatform boots cloud images (Ubuntu, Debian, FreeBSD) via QEMU. +type qemuCloudPlatform struct{} + +func (qemuCloudPlatform) planSteps(e *Env, n *Node) { + e.Step(fmt.Sprintf("Compile %s_%s binaries", n.os.GOOS(), n.os.GOARCH())) + e.Step(fmt.Sprintf("Prepare %s image", n.os.Name)) + e.Step("Launch QEMU: " + n.name) +} + +func (qemuCloudPlatform) boot(ctx context.Context, e *Env, n *Node) error { + goos, goarch := n.os.GOOS(), n.os.GOARCH() + + e.ensureCompiled(ctx, goos, goarch) + + if err := e.ensureImage(ctx, n.os); err != nil { + return err + } + + e.ensureQEMUSocket() + + vmStep := e.Step("Launch QEMU: " + n.name) + vmStep.Begin() + if err := e.startCloudQEMU(n); err != nil { + vmStep.End(err) + return err + } + vmStep.End(nil) + return nil +} + +// startGokrazyQEMU launches a QEMU process for a gokrazy node. +// This follows the same pattern as tstest/integration/nat/nat_test.go. +func (e *Env) startGokrazyQEMU(n *Node) error { + disk := filepath.Join(e.tempDir, fmt.Sprintf("%s.qcow2", n.name)) + if err := createOverlay(e.gokrazyBase, disk); err != nil { + return err + } + + var envBuf bytes.Buffer + for _, env := range n.vnetNode.Env() { + fmt.Fprintf(&envBuf, " tailscaled.env=%s=%s", env.Key, env.Value) + } + sysLogAddr := net.JoinHostPort(vnet.FakeSyslogIPv4().String(), "995") + if n.vnetNode.IsV6Only() { + sysLogAddr = net.JoinHostPort(vnet.FakeSyslogIPv6().String(), "995") + } + + logPath := filepath.Join(e.tempDir, n.name+".log") + + args := []string{ + "-M", "microvm,isa-serial=off", + "-m", fmt.Sprintf("%dM", n.os.MemoryMB), + "-nodefaults", "-no-user-config", "-nographic", + "-kernel", e.gokrazyKernel, + "-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-76baa2d60001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet gokrazy.remote_syslog.target=" + sysLogAddr + " tailscale-tta=1" + envBuf.String(), + "-drive", "id=blk0,file=" + disk + ",format=qcow2", + "-device", "virtio-blk-device,drive=blk0", + "-device", "virtio-serial-device", + "-device", "virtio-rng-device", + "-chardev", "file,id=virtiocon0,path=" + logPath, + "-device", "virtconsole,chardev=virtiocon0", + } + + // Add network devices — one per NIC. + for i := range n.vnetNode.NumNICs() { + mac := n.vnetNode.NICMac(i) + netdevID := fmt.Sprintf("net%d", i) + args = append(args, + "-netdev", fmt.Sprintf("stream,id=%s,addr.type=unix,addr.path=%s", netdevID, e.sockAddr), + "-device", fmt.Sprintf("virtio-net-device,netdev=%s,mac=%s", netdevID, mac), + ) + } + + return e.launchQEMU(n.name, logPath, args) +} + +// startCloudQEMU launches a QEMU process for a cloud image (Ubuntu, Debian, FreeBSD, etc). +func (e *Env) startCloudQEMU(n *Node) error { + basePath := cachedImagePath(n.os) + disk := filepath.Join(e.tempDir, fmt.Sprintf("%s.qcow2", n.name)) + if err := createOverlay(basePath, disk); err != nil { + return err + } + + // Create a seed ISO with cloud-init config (meta-data, user-data, network-config). + // This MUST be a local ISO (not HTTP) so cloud-init reads network-config during + // init-local, before systemd-networkd-wait-online blocks boot. + seedISO, err := e.createCloudInitISO(n) + if err != nil { + return fmt.Errorf("creating cloud-init ISO: %w", err) + } + + logPath := filepath.Join(e.tempDir, n.name+".log") + qmpSock := filepath.Join(e.tempDir, n.name+"-qmp.sock") + + args := []string{ + "-machine", "q35,accel=kvm", + "-m", fmt.Sprintf("%dM", n.os.MemoryMB), + "-cpu", "host", + "-smp", "2", + "-display", "none", + "-drive", fmt.Sprintf("file=%s,if=virtio", disk), + "-drive", fmt.Sprintf("file=%s,if=virtio,media=cdrom,readonly=on", seedISO), + "-smbios", "type=1,serial=ds=nocloud", + "-serial", "file:" + logPath, + "-qmp", "unix:" + qmpSock + ",server,nowait", + } + + // Add network devices — one per NIC. + // romfile="" disables the iPXE option ROM entirely, saving ~5s per NIC at boot + // and avoiding "duplicate fw_cfg file name" errors with multiple NICs. + for i := range n.vnetNode.NumNICs() { + mac := n.vnetNode.NICMac(i) + netdevID := fmt.Sprintf("net%d", i) + args = append(args, + "-netdev", fmt.Sprintf("stream,id=%s,addr.type=unix,addr.path=%s", netdevID, e.sockAddr), + "-device", fmt.Sprintf("virtio-net-pci,netdev=%s,mac=%s,romfile=", netdevID, mac), + ) + } + + // Add a debug NIC with user-mode networking for SSH access from the host. + // Use port 0 so the OS picks a free port; we query the actual port via QMP after launch. + args = append(args, + "-netdev", "user,id=debug0,hostfwd=tcp:127.0.0.1:0-:22", + "-device", "virtio-net-pci,netdev=debug0,romfile=", + ) + + if err := e.launchQEMU(n.name, logPath, args); err != nil { + return err + } + + // Query QMP to find the actual SSH port that QEMU allocated. + port, err := qmpQueryHostFwd(qmpSock) + if err != nil { + return fmt.Errorf("querying SSH port via QMP: %w", err) + } + n.sshPort = port + e.t.Logf("[%s] SSH debug: ssh -p %d root@127.0.0.1 (password: root)", n.name, port) + return nil +} + +// launchQEMU starts a qemu-system-x86_64 process with the given args. +// VM console output goes to logPath (via QEMU's -serial or -chardev). +// QEMU's own stdout/stderr go to logPath.qemu for diagnostics. +func (e *Env) launchQEMU(name, logPath string, args []string) error { + cmd := exec.Command("qemu-system-x86_64", args...) + // Send stdout/stderr to the log file for any QEMU diagnostic messages. + // Stdin must be /dev/null to prevent QEMU from trying to read. + devNull, err := os.Open(os.DevNull) + if err != nil { + return fmt.Errorf("open /dev/null: %w", err) + } + cmd.Stdin = devNull + qemuLog, err := os.Create(logPath + ".qemu") + if err != nil { + devNull.Close() + return err + } + cmd.Stdout = qemuLog + cmd.Stderr = qemuLog + if err := cmd.Start(); err != nil { + devNull.Close() + qemuLog.Close() + return fmt.Errorf("qemu for %s: %w", name, err) + } + e.t.Logf("launched QEMU for %s (pid %d), log: %s", name, cmd.Process.Pid, logPath) + e.qemuProcs = append(e.qemuProcs, cmd) + + // Start tailing the VM console log for the web UI. + if e.ctx != nil { + go e.tailLogFile(e.ctx, name, logPath) + } + e.t.Cleanup(func() { + cmd.Process.Kill() + cmd.Wait() + devNull.Close() + qemuLog.Close() + // Dump tail of VM log on failure for debugging. + if e.t.Failed() { + if data, err := os.ReadFile(logPath); err == nil { + lines := bytes.Split(data, []byte("\n")) + start := 0 + if len(lines) > 50 { + start = len(lines) - 50 + } + e.t.Logf("=== last 50 lines of %s log ===", name) + for _, line := range lines[start:] { + e.t.Logf("[%s] %s", name, line) + } + } + } + }) + return nil +} + +// qmpQueryHostFwd connects to a QEMU QMP socket and queries the host port +// assigned to the first TCP host forward rule (the SSH debug port). +func qmpQueryHostFwd(sockPath string) (int, error) { + // Wait for the QMP socket to appear. + var conn net.Conn + for range 50 { + var err error + conn, err = net.Dial("unix", sockPath) + if err == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + if conn == nil { + return 0, fmt.Errorf("QMP socket %s not available", sockPath) + } + defer conn.Close() + conn.SetDeadline(time.Now().Add(5 * time.Second)) + + // Read the QMP greeting. + var greeting json.RawMessage + dec := json.NewDecoder(conn) + if err := dec.Decode(&greeting); err != nil { + return 0, fmt.Errorf("reading QMP greeting: %w", err) + } + + // Send qmp_capabilities to initialize. + fmt.Fprintf(conn, `{"execute":"qmp_capabilities"}`+"\n") + var capsResp json.RawMessage + if err := dec.Decode(&capsResp); err != nil { + return 0, fmt.Errorf("reading qmp_capabilities response: %w", err) + } + + // Query "info usernet" via human-monitor-command. + fmt.Fprintf(conn, `{"execute":"human-monitor-command","arguments":{"command-line":"info usernet"}}`+"\n") + var hmpResp struct { + Return string `json:"return"` + } + if err := dec.Decode(&hmpResp); err != nil { + return 0, fmt.Errorf("reading info usernet response: %w", err) + } + + // Parse the port from output like: + // TCP[HOST_FORWARD] 12 127.0.0.1 35323 10.0.2.15 22 + re := regexp.MustCompile(`TCP\[HOST_FORWARD\]\s+\d+\s+127\.0\.0\.1\s+(\d+)\s+`) + m := re.FindStringSubmatch(hmpResp.Return) + if m == nil { + return 0, fmt.Errorf("no hostfwd port found in: %s", hmpResp.Return) + } + return strconv.Atoi(m[1]) +} + +// tailLogFile tails a VM's serial console log file and publishes each line +// as an EventConsoleOutput to the event bus for the web UI. +func (e *Env) tailLogFile(ctx context.Context, name, logPath string) { + // Wait for the file to appear (QEMU may not have created it yet). + var f *os.File + for { + var err error + f, err = os.Open(logPath) + if err == nil { + break + } + select { + case <-ctx.Done(): + return + case <-time.After(100 * time.Millisecond): + } + } + defer f.Close() + + // Read the file in a loop, tracking our position manually. + // We can't use bufio.Scanner because it caches EOF and won't + // pick up new data appended by QEMU after the first EOF. + var buf []byte + var partial string // incomplete line (no trailing newline yet) + readBuf := make([]byte, 4096) + for { + n, err := f.Read(readBuf) + if n > 0 { + buf = append(buf, readBuf[:n]...) + // Split into complete lines. + for { + idx := bytes.IndexByte(buf, '\n') + if idx < 0 { + break + } + line := partial + string(buf[:idx]) + partial = "" + buf = buf[idx+1:] + // Strip trailing \r from serial consoles. + line = strings.TrimRight(line, "\r") + if line == "" { + continue + } + e.appendConsoleLine(name, line) + e.eventBus.Publish(VMEvent{ + NodeName: name, + Type: EventConsoleOutput, + Message: line, + }) + } + if len(buf) > 0 { + partial = string(buf) + buf = buf[:0] + } + } + if err != nil || n == 0 { + // EOF or error — wait for more data. + select { + case <-ctx.Done(): + return + case <-time.After(100 * time.Millisecond): + } + } + } +} diff --git a/tstest/natlab/vmtest/tailmac.go b/tstest/natlab/vmtest/tailmac.go new file mode 100644 index 0000000000000..167feeb04b5c1 --- /dev/null +++ b/tstest/natlab/vmtest/tailmac.go @@ -0,0 +1,736 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +// macPlatform boots macOS VMs via Tart base images and tailmac Host.app. +type macPlatform struct{} + +func (macPlatform) planSteps(e *Env, n *Node) { + e.Step("Prepare macOS Tart image") + e.Step("Launch macOS VM: " + n.name) +} + +func (macPlatform) boot(ctx context.Context, e *Env, n *Node) error { + imgStep := e.Step("Prepare macOS Tart image") + e.macosSnapshotOnce.Do(func() { + imgStep.Begin() + e.macosSnapshot = ensureSnapshot(e.t) + imgStep.End(nil) + }) + + e.ensureDgramSocket() + + vmStep := e.Step("Launch macOS VM: " + n.name) + vmStep.Begin() + if err := e.startTailMacVM(n); err != nil { + vmStep.End(err) + return err + } + vmStep.End(nil) + return nil +} + +const tartImage = "ghcr.io/cirruslabs/macos-tahoe-base:latest" + +// macOSSnapshotCodeVersion is bumped when the snapshot preparation logic +// changes in a way that invalidates old snapshots. Old snapshots with a +// different version are cleaned up automatically. +const macOSSnapshotCodeVersion = 5 + +// tartConfig is the subset of Tart's config.json we need. +type tartConfig struct { + HardwareModel string `json:"hardwareModel"` // base64 + ECID string `json:"ecid"` // base64 +} + +// tartManifest is the subset of Tart's OCI manifest.json we need. +type tartManifest struct { + Config struct { + Digest string `json:"digest"` // e.g. "sha256:3a6cb4eb6201..." + } `json:"config"` +} + +// ensureTartImage checks that the Tart base image is available, pulling it +// if necessary. Returns the path to the OCI cache directory containing +// disk.img, nvram.bin, config.json, and manifest.json. +func ensureTartImage(t testing.TB) string { + if _, err := exec.LookPath("tart"); err != nil { + t.Skip("tart not installed; skipping macOS VM test") + } + + home, err := os.UserHomeDir() + if err != nil { + t.Fatalf("UserHomeDir: %v", err) + } + + ociDir := filepath.Join(home, ".tart", "cache", "OCIs", + "ghcr.io", "cirruslabs", "macos-tahoe-base", "latest") + if _, err := os.Stat(filepath.Join(ociDir, "disk.img")); err == nil { + return ociDir + } + + t.Logf("pulling Tart image %s ...", tartImage) + cmd := exec.Command("tart", "pull", tartImage) + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + t.Fatalf("tart pull: %v", err) + } + + if _, err := os.Stat(filepath.Join(ociDir, "disk.img")); err == nil { + return ociDir + } + t.Fatalf("tart pull succeeded but image not found at %s", ociDir) + return "" +} + +// snapshotCacheKey computes a cache key for the macOS VM snapshot. +// The key combines the image name, the first 12 hex chars of the Tart +// config digest (changes when the upstream image is updated), and the +// snapshot code version (changes when our prep logic changes). +func snapshotCacheKey(tartDir string) (string, error) { + manifestPath := filepath.Join(tartDir, "manifest.json") + data, err := os.ReadFile(manifestPath) + if err != nil { + return "", fmt.Errorf("reading manifest: %w", err) + } + var m tartManifest + if err := json.Unmarshal(data, &m); err != nil { + return "", fmt.Errorf("parsing manifest: %w", err) + } + digest := m.Config.Digest + // Strip "sha256:" prefix and take first 12 hex chars. + digest = strings.TrimPrefix(digest, "sha256:") + if len(digest) > 12 { + digest = digest[:12] + } + return fmt.Sprintf("snap-tahoe-%s-v%d", digest, macOSSnapshotCodeVersion), nil +} + +// macosVMBaseDir returns ~/.cache/tailscale/vmtest/macos/, the directory +// where Host.app expects to find VM directories by ID. +func macosVMBaseDir() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(home, ".cache", "tailscale", "vmtest", "macos"), nil +} + +// cleanOldSnapshots removes any snapshot directories for the given image +// prefix (e.g. "snap-tahoe") that don't match the current cache key. +func cleanOldSnapshots(t testing.TB, imagePrefix, currentKey string) { + base, err := macosVMBaseDir() + if err != nil { + return + } + matches, _ := filepath.Glob(filepath.Join(base, imagePrefix+"-*")) + currentPath := filepath.Join(base, currentKey) + for _, m := range matches { + if m != currentPath { + t.Logf("removing stale snapshot: %s", filepath.Base(m)) + os.RemoveAll(m) + } + } +} + +// ensureSnapshot returns the path to a cached macOS VM snapshot, creating +// one if necessary. The snapshot contains a fully booted VM with +// SaveFile.vzvmsave ready for fast restore. +func ensureSnapshot(t testing.TB) string { + tartDir := ensureTartImage(t) + + key, err := snapshotCacheKey(tartDir) + if err != nil { + t.Fatalf("snapshot cache key: %v", err) + } + + base, err := macosVMBaseDir() + if err != nil { + t.Fatalf("macOS VM base dir: %v", err) + } + os.MkdirAll(base, 0755) + + snapDir := filepath.Join(base, key) + saveFile := filepath.Join(snapDir, "SaveFile.vzvmsave") + if _, err := os.Stat(saveFile); err == nil { + t.Logf("using cached macOS snapshot: %s", key) + return snapDir + } + + // Clean up old snapshots for this image. + cleanOldSnapshots(t, "snap-tahoe", key) + + t.Logf("preparing macOS snapshot: %s (this takes ~30s on first run)", key) + if err := prepareSnapshot(t, tartDir, snapDir); err != nil { + os.RemoveAll(snapDir) + t.Fatalf("preparing snapshot: %v", err) + } + return snapDir +} + +// prepareSnapshot creates a new macOS VM snapshot by booting the Tart base +// image with a NAT NIC, waiting for SSH, and saving VM state. +func prepareSnapshot(t testing.TB, tartDir, snapDir string) error { + // The vmID must match the directory name under macosVMBaseDir + // because Host.app looks up VM files at //. + snapID := filepath.Base(snapDir) + + if err := cloneTartToTailmac(tartDir, snapDir, snapID, "52:cc:cc:cc:ce:01", "/dev/null"); err != nil { + return fmt.Errorf("cloning tart: %w", err) + } + + modRoot, err := findModRoot() + if err != nil { + return err + } + tailmacDir := filepath.Join(modRoot, "tstest", "tailmac", "bin") + hostBin := filepath.Join(tailmacDir, "Host.app", "Contents", "MacOS", "Host") + if _, err := os.Stat(hostBin); err != nil { + return fmt.Errorf("Host.app not found at %s; run 'make all' in tstest/tailmac/", hostBin) + } + + // Host.app reads VM files from ~/.cache/tailscale/vmtest/macos//. + // Our snapDir is already under that tree, and the config.json vmID matches. + cmd := exec.Command(hostBin, "run", "--id", snapID, "--headless", "--nat-nic") + cmd.Env = append(os.Environ(), "NSUnbufferedIO=YES") + + logPath := snapDir + ".prep.log" + logFile, err := os.Create(logPath) + if err != nil { + return err + } + defer logFile.Close() + cmd.Stdout = logFile + cmd.Stderr = logFile + devNull, _ := os.Open(os.DevNull) + cmd.Stdin = devNull + defer devNull.Close() + + if err := cmd.Start(); err != nil { + return fmt.Errorf("starting Host.app: %w", err) + } + t.Logf("snapshot prep: launched Host.app (pid %d)", cmd.Process.Pid) + + // Wait for SSH to become available via the NAT NIC. + // The VM gets an IP from macOS's vmnet DHCP (typically 192.168.64.x). + ip, err := waitForVMIP(t, "52:cc:cc:cc:ce:01", 60*time.Second) + if err != nil { + cmd.Process.Kill() + cmd.Wait() + return fmt.Errorf("waiting for VM IP: %w", err) + } + t.Logf("snapshot prep: VM IP is %s, waiting for SSH...", ip) + + sc, err := waitForSSH(ip, 60*time.Second) + if err != nil { + cmd.Process.Kill() + cmd.Wait() + return fmt.Errorf("waiting for SSH: %w", err) + } + t.Logf("snapshot prep: SSH connected") + + // Compile and install TTA in the macOS VM. + t.Logf("snapshot prep: installing TTA...") + if err := installTTA(t, sc); err != nil { + sc.Close() + cmd.Process.Kill() + cmd.Wait() + return fmt.Errorf("installing TTA: %w", err) + } + sc.Close() + + // Save VM state by sending SIGINT. + t.Logf("snapshot prep: saving VM state...") + cmd.Process.Signal(os.Interrupt) + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() + select { + case err := <-done: + if err != nil { + // Host.app exits 0 after saving state, non-zero is unexpected. + t.Logf("snapshot prep: Host.app exited with: %v", err) + } + case <-time.After(60 * time.Second): + cmd.Process.Kill() + <-done + return fmt.Errorf("Host.app did not exit after SIGINT") + } + + // Verify the save file was created. + saveFile := filepath.Join(snapDir, "SaveFile.vzvmsave") + if _, err := os.Stat(saveFile); err != nil { + return fmt.Errorf("SaveFile.vzvmsave not found after prep") + } + t.Logf("snapshot prep: done, saved to %s", filepath.Base(snapDir)) + os.Remove(logPath) + return nil +} + +// installTTA compiles TTA for darwin/arm64 and installs it in the macOS VM +// as a LaunchDaemon via SSH/SCP. +func installTTA(t testing.TB, sc *ssh.Client) error { + modRoot, err := findModRoot() + if err != nil { + return err + } + + // Compile TTA for the macOS VM. + tmpDir := t.TempDir() + ttaBin := filepath.Join(tmpDir, "tta") + t.Logf("snapshot prep: compiling TTA for darwin/arm64...") + buildCmd := exec.Command("go", "build", "-o", ttaBin, "./cmd/tta") + buildCmd.Dir = modRoot + buildCmd.Env = append(os.Environ(), "GOOS=darwin", "GOARCH=arm64", "CGO_ENABLED=0") + if out, err := buildCmd.CombinedOutput(); err != nil { + return fmt.Errorf("compiling TTA: %v\n%s", err, out) + } + + // Read the binary. + ttaData, err := os.ReadFile(ttaBin) + if err != nil { + return fmt.Errorf("reading TTA binary: %w", err) + } + t.Logf("snapshot prep: TTA binary is %d bytes", len(ttaData)) + + // SCP the TTA binary to the VM via a temp file (admin user can't write /usr/local/bin directly). + if err := scpFile(sc, ttaData, "/tmp/tta", 0755); err != nil { + return fmt.Errorf("uploading TTA: %w", err) + } + if err := runSSHCmd(sc, "echo admin | sudo -S mv /tmp/tta /usr/local/bin/tta"); err != nil { + return fmt.Errorf("moving TTA to /usr/local/bin: %w", err) + } + + // Install the LaunchDaemon plist. + plist := ` + + + + Label + com.tailscale.tta + ProgramArguments + + /usr/local/bin/tta + + RunAtLoad + + KeepAlive + + StandardOutPath + /tmp/tta.log + StandardErrorPath + /tmp/tta.log + + +` + if err := scpFile(sc, []byte(plist), "/tmp/com.tailscale.tta.plist", 0644); err != nil { + return fmt.Errorf("uploading plist: %w", err) + } + if err := runSSHCmd(sc, "echo admin | sudo -S mv /tmp/com.tailscale.tta.plist /Library/LaunchDaemons/ && echo admin | sudo -S chown root:wheel /Library/LaunchDaemons/com.tailscale.tta.plist"); err != nil { + return fmt.Errorf("installing plist: %w", err) + } + + // Load the LaunchDaemon. + if err := runSSHCmd(sc, "echo admin | sudo -S launchctl load /Library/LaunchDaemons/com.tailscale.tta.plist"); err != nil { + return fmt.Errorf("loading LaunchDaemon: %w", err) + } + + // Wait for TTA to start. + for range 20 { + if err := runSSHCmd(sc, "pgrep -x tta"); err == nil { + break + } + time.Sleep(250 * time.Millisecond) + } + if err := runSSHCmd(sc, "pgrep -x tta"); err != nil { + return fmt.Errorf("TTA not running after install: %w", err) + } + t.Logf("snapshot prep: TTA installed and running") + return nil +} + +// scpFile uploads data to a remote path via SSH/SCP. +func scpFile(sc *ssh.Client, data []byte, remotePath string, mode os.FileMode) error { + sess, err := sc.NewSession() + if err != nil { + return err + } + defer sess.Close() + + // Use a simple shell command to write the file. + cmd := fmt.Sprintf("cat > %s && chmod %o %s", remotePath, mode, remotePath) + sess.Stdin = bytes.NewReader(data) + out, err := sess.CombinedOutput(cmd) + if err != nil { + return fmt.Errorf("%s: %v: %s", cmd, err, out) + } + return nil +} + +// runSSHCmd runs a command on the SSH client and returns an error if it fails. +func runSSHCmd(sc *ssh.Client, cmd string) error { + sess, err := sc.NewSession() + if err != nil { + return err + } + defer sess.Close() + out, err := sess.CombinedOutput(cmd) + if err != nil { + return fmt.Errorf("%s: %v: %s", cmd, err, out) + } + return nil +} + +// waitForVMIP polls /var/db/dhcpd_leases for a DHCP lease matching the +// given MAC address (from macOS's vmnet NAT). Returns the IP. +func waitForVMIP(t testing.TB, mac string, timeout time.Duration) (string, error) { + // Normalize MAC format: vmnet leases use "1,xx:xx:xx:xx:xx:xx" format + // with leading zeros stripped from each octet (e.g. "1,52:cc:cc:cc:ce:1" + // instead of "1,52:cc:cc:cc:ce:01"). + mac = strings.ToLower(mac) + parts := strings.Split(mac, ":") + for i, p := range parts { + parts[i] = strings.TrimLeft(p, "0") + if parts[i] == "" { + parts[i] = "0" + } + } + leaseMAC := "1," + strings.Join(parts, ":") + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + data, err := os.ReadFile("/var/db/dhcpd_leases") + if err == nil { + // Parse the plist-like lease file. + lines := strings.Split(string(data), "\n") + var currentIP string + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "ip_address=") { + currentIP = strings.TrimPrefix(line, "ip_address=") + } + if strings.HasPrefix(line, "hw_address=") { + hw := strings.TrimPrefix(line, "hw_address=") + if strings.ToLower(hw) == leaseMAC && currentIP != "" { + return currentIP, nil + } + } + if line == "}" { + currentIP = "" + } + } + } + time.Sleep(time.Second) + } + return "", fmt.Errorf("no DHCP lease for MAC %s after %v", mac, timeout) +} + +// waitForSSH retries SSH connection to the given IP until it succeeds or +// the timeout expires. +func waitForSSH(ip string, timeout time.Duration) (*ssh.Client, error) { + deadline := time.Now().Add(timeout) + addr := net.JoinHostPort(ip, "22") + cfg := &ssh.ClientConfig{ + User: "admin", + Auth: []ssh.AuthMethod{ssh.Password("admin")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + for time.Now().Before(deadline) { + sc, err := ssh.Dial("tcp", addr, cfg) + if err == nil { + return sc, nil + } + time.Sleep(time.Second) + } + return nil, fmt.Errorf("SSH to %s timed out after %v", addr, timeout) +} + +// ensureTailMac locates the pre-built tailmac Host.app binary. +func (e *Env) ensureTailMac() error { + modRoot, err := findModRoot() + if err != nil { + return err + } + e.tailmacDir = filepath.Join(modRoot, "tstest", "tailmac", "bin") + hostApp := filepath.Join(e.tailmacDir, "Host.app", "Contents", "MacOS", "Host") + if _, err := os.Stat(hostApp); err != nil { + return fmt.Errorf("tailmac Host.app not found at %s; run 'make all' in tstest/tailmac/", hostApp) + } + return nil +} + +// cloneTartToTailmac creates a tailmac-compatible VM directory from a Tart +// base image. It uses APFS CoW clones for the disk and NVRAM, and extracts +// the hardware identity from Tart's config.json. +func cloneTartToTailmac(tartDir, cloneDir, testID, mac, dgramSock string) error { + if err := os.MkdirAll(cloneDir, 0755); err != nil { + return err + } + + cfgData, err := os.ReadFile(filepath.Join(tartDir, "config.json")) + if err != nil { + return fmt.Errorf("reading tart config: %w", err) + } + var tc tartConfig + if err := json.Unmarshal(cfgData, &tc); err != nil { + return fmt.Errorf("parsing tart config: %w", err) + } + + hwModel, err := base64.StdEncoding.DecodeString(tc.HardwareModel) + if err != nil { + return fmt.Errorf("decoding hardwareModel: %w", err) + } + if err := os.WriteFile(filepath.Join(cloneDir, "HardwareModel"), hwModel, 0644); err != nil { + return err + } + + ecid, err := base64.StdEncoding.DecodeString(tc.ECID) + if err != nil { + return fmt.Errorf("decoding ecid: %w", err) + } + if err := os.WriteFile(filepath.Join(cloneDir, "MachineIdentifier"), ecid, 0644); err != nil { + return err + } + + if out, err := exec.Command("cp", "-c", filepath.Join(tartDir, "disk.img"), filepath.Join(cloneDir, "Disk.img")).CombinedOutput(); err != nil { + if out2, err2 := exec.Command("cp", filepath.Join(tartDir, "disk.img"), filepath.Join(cloneDir, "Disk.img")).CombinedOutput(); err2 != nil { + return fmt.Errorf("copying disk: %v: %s (APFS clone: %v: %s)", err2, out2, err, out) + } + } + + if out, err := exec.Command("cp", "-c", filepath.Join(tartDir, "nvram.bin"), filepath.Join(cloneDir, "AuxiliaryStorage")).CombinedOutput(); err != nil { + if out2, err2 := exec.Command("cp", filepath.Join(tartDir, "nvram.bin"), filepath.Join(cloneDir, "AuxiliaryStorage")).CombinedOutput(); err2 != nil { + return fmt.Errorf("copying nvram: %v: %s (APFS clone: %v: %s)", err2, out2, err, out) + } + } + + tmCfg := struct { + VMid string `json:"vmID"` + ServerSocket string `json:"serverSocket"` + MemorySize uint64 `json:"memorySize"` + Mac string `json:"mac"` + }{ + VMid: testID, + ServerSocket: dgramSock, + MemorySize: 4 * 1024 * 1024 * 1024, + Mac: mac, + } + tmData, _ := json.MarshalIndent(tmCfg, "", " ") + return os.WriteFile(filepath.Join(cloneDir, "config.json"), tmData, 0644) +} + +// startTailMacVM restores a macOS VM from a cached snapshot and launches it +// via tailmac Host.app in headless mode, connected to vnet's dgram socket. +func (e *Env) startTailMacVM(n *Node) error { + snapDir := e.macosSnapshot + + if err := e.ensureTailMac(); err != nil { + return err + } + + testID := fmt.Sprintf("vmtest-%s-%d", n.name, os.Getpid()) + + // Host.app expects VM files under ~/.cache/tailscale/vmtest/macos// + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("UserHomeDir: %w", err) + } + vmBase := filepath.Join(home, ".cache", "tailscale", "vmtest", "macos") + os.MkdirAll(vmBase, 0755) + cloneDir := filepath.Join(vmBase, testID) + + // APFS clone the entire snapshot directory (includes SaveFile.vzvmsave). + e.t.Logf("[%s] cloning snapshot -> %s", n.name, testID) + if out, err := exec.Command("cp", "-c", "-r", snapDir, cloneDir).CombinedOutput(); err != nil { + if out2, err2 := exec.Command("cp", "-r", snapDir, cloneDir).CombinedOutput(); err2 != nil { + return fmt.Errorf("cloning snapshot: %v: %s (APFS clone: %v: %s)", err2, out2, err, out) + } + } + e.t.Cleanup(func() { os.RemoveAll(cloneDir) }) + + // Write test-specific config.json with the vnet MAC and dgram socket. + mac := n.vnetNode.NICMac(0) + cfg := struct { + VMid string `json:"vmID"` + ServerSocket string `json:"serverSocket"` + MemorySize uint64 `json:"memorySize"` + Mac string `json:"mac"` + }{ + VMid: testID, + ServerSocket: e.dgramSockAddr, + MemorySize: 8 * 1024 * 1024 * 1024, + Mac: mac.String(), + } + cfgData, _ := json.MarshalIndent(cfg, "", " ") + if err := os.WriteFile(filepath.Join(cloneDir, "config.json"), cfgData, 0644); err != nil { + return fmt.Errorf("writing config.json: %w", err) + } + + // Launch Host.app with disconnected NIC + hot-swap to vnet. + // Host.app will restore from SaveFile.vzvmsave (fast), then + // hot-swap the NIC to the vnet dgram socket. + hostBin := filepath.Join(e.tailmacDir, "Host.app", "Contents", "MacOS", "Host") + + // Compute the node's IP and gateway for static assignment via vsock. + nodeIP := n.vnetNode.LanIP(n.nets[0]) + // The gateway is the network's base address (e.g. 192.168.1.1 for /24). + // We derive it from the node IP: same /24 prefix, host part = 1. + gwIP := nodeIP.As4() + gwIP[3] = 1 + gateway := netip.AddrFrom4(gwIP) + + args := []string{ + "run", "--id", testID, "--headless", + "--disconnected-nic", + "--attach-network", e.dgramSockAddr, + "--assign-ip", fmt.Sprintf("%s/255.255.255.0/%s", nodeIP, gateway), + } + + wantScreenshots := *vmtestWeb != "" + if wantScreenshots { + args = append(args, "--screenshot-port", "0") + } + + logPath := filepath.Join(e.tempDir, n.name+"-tailmac.log") + logFile, err := os.Create(logPath) + if err != nil { + return fmt.Errorf("creating log file: %w", err) + } + + cmd := exec.Command(hostBin, args...) + cmd.Env = append(os.Environ(), "NSUnbufferedIO=YES") + + var stdoutPipe io.ReadCloser + if wantScreenshots { + stdoutPipe, err = cmd.StdoutPipe() + if err != nil { + logFile.Close() + return fmt.Errorf("stdout pipe: %w", err) + } + cmd.Stderr = logFile + } else { + cmd.Stdout = logFile + cmd.Stderr = logFile + } + devNull, err := os.Open(os.DevNull) + if err != nil { + logFile.Close() + return fmt.Errorf("open /dev/null: %w", err) + } + cmd.Stdin = devNull + + if err := cmd.Start(); err != nil { + devNull.Close() + logFile.Close() + return fmt.Errorf("starting tailmac for %s: %w", n.name, err) + } + e.t.Logf("[%s] launched tailmac (pid %d), log: %s", n.name, cmd.Process.Pid, logPath) + + if wantScreenshots { + screenshotPortCh := make(chan int, 1) + go func() { + scanner := bufio.NewScanner(stdoutPipe) + for scanner.Scan() { + line := scanner.Text() + fmt.Fprintln(logFile, line) + if port := 0; strings.HasPrefix(line, "SCREENSHOT_PORT=") { + fmt.Sscanf(line, "SCREENSHOT_PORT=%d", &port) + if port > 0 { + screenshotPortCh <- port + } + } + } + }() + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + select { + case port := <-screenshotPortCh: + e.t.Logf("[%s] screenshot server on port %d", n.name, port) + e.setNodeScreenshotPort(n.name, port) + e.tailScreenshots(n.name, port) + case <-ctx.Done(): + e.t.Logf("[%s] screenshot port not received", n.name) + } + }() + } + + clientSock := fmt.Sprintf("/tmp/qemu-dgram-%s.sock", testID) + + e.t.Cleanup(func() { + // Kill immediately — no need to save state for ephemeral test clones. + cmd.Process.Kill() + cmd.Wait() + devNull.Close() + logFile.Close() + os.Remove(clientSock) + + if e.t.Failed() { + if data, err := os.ReadFile(logPath); err == nil { + lines := strings.Split(string(data), "\n") + start := 0 + if len(lines) > 50 { + start = len(lines) - 50 + } + e.t.Logf("=== last 50 lines of %s tailmac log ===", n.name) + for _, line := range lines[start:] { + e.t.Logf("[%s] %s", n.name, line) + } + } + } + }) + + return nil +} + +// tailScreenshots polls the Host.app screenshot HTTP server every 2 seconds +// and publishes each screenshot as a base64 data URI to the web UI. +func (e *Env) tailScreenshots(name string, port int) { + url := fmt.Sprintf("http://127.0.0.1:%d/screenshot", port) + client := &http.Client{Timeout: 5 * time.Second} + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for range ticker.C { + resp, err := client.Get(url) + if err != nil { + continue + } + data, _ := io.ReadAll(resp.Body) + resp.Body.Close() + if resp.StatusCode != 200 || len(data) == 0 { + continue + } + b64 := base64.StdEncoding.EncodeToString(data) + dataURI := "data:image/jpeg;base64," + b64 + e.setNodeScreenshot(name, dataURI) + e.eventBus.Publish(VMEvent{ + NodeName: name, + Type: EventScreenshot, + Message: b64, + }) + } +} diff --git a/tstest/natlab/vmtest/version.go b/tstest/natlab/vmtest/version.go new file mode 100644 index 0000000000000..7e76716e4016f --- /dev/null +++ b/tstest/natlab/vmtest/version.go @@ -0,0 +1,195 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "archive/tar" + "compress/gzip" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path" + "path/filepath" + "regexp" + "strconv" + "strings" + + "tailscale.com/types/logger" +) + +// versionRE matches a concrete X.Y.Z release version. +var versionRE = regexp.MustCompile(`^\d+\.\d+\.\d+$`) + +// resolveTestVersion returns the concrete release version (e.g. "1.97.255") +// for the given --test-version flag value. If v is "unstable" or "stable", it +// queries pkgs.tailscale.com for the latest TarballsVersion on that track. +// Otherwise it returns v unchanged. +func resolveTestVersion(ctx context.Context, v string) (string, error) { + if v != "unstable" && v != "stable" { + if !versionRE.MatchString(v) { + return "", fmt.Errorf("invalid --test-version %q: want \"stable\", \"unstable\", or X.Y.Z", v) + } + return v, nil + } + url := "https://pkgs.tailscale.com/" + v + "/?mode=json" + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("fetching %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return "", fmt.Errorf("fetching %s: HTTP %s", url, resp.Status) + } + var meta struct { + TarballsVersion string + } + if err := json.NewDecoder(resp.Body).Decode(&meta); err != nil { + return "", fmt.Errorf("decoding %s: %w", url, err) + } + if meta.TarballsVersion == "" { + return "", fmt.Errorf("no TarballsVersion in %s response", url) + } + return meta.TarballsVersion, nil +} + +// versionTrack returns the pkgs.tailscale.com track ("stable" or "unstable") +// for a release version. Even minors are stable; odd minors are unstable. +func versionTrack(version string) (string, error) { + parts := strings.Split(version, ".") + if len(parts) < 2 { + return "", fmt.Errorf("bad version %q (expected like 1.97.255)", version) + } + minor, err := strconv.Atoi(parts[1]) + if err != nil { + return "", fmt.Errorf("bad minor in version %q: %w", version, err) + } + if minor%2 == 0 { + return "stable", nil + } + return "unstable", nil +} + +// versionCacheRoot returns the root cache directory for downloaded version +// tarballs. +func versionCacheRoot() string { + if d := os.Getenv("VMTEST_BUILDS_CACHE_DIR"); d != "" { + return d + } + cache, err := os.UserCacheDir() + if err != nil { + panic(fmt.Sprintf("os.UserCacheDir: %v", err)) + } + return filepath.Join(cache, "tailscale-vmtest", "builds") +} + +// versionCacheDir returns the directory holding the extracted binaries for +// the given version+arch. +func versionCacheDir(version, arch string) string { + return filepath.Join(versionCacheRoot(), fmt.Sprintf("%s_%s", version, arch)) +} + +// ensureVersionBinaries downloads (if needed) and extracts the tailscale +// release tarball for the given concrete version+arch, returning the +// directory containing tailscale and tailscaled. +func ensureVersionBinaries(ctx context.Context, version, arch string, logf logger.Logf) (string, error) { + dir := versionCacheDir(version, arch) + tailscaled := filepath.Join(dir, "tailscaled") + tailscale := filepath.Join(dir, "tailscale") + if _, err1 := os.Stat(tailscaled); err1 == nil { + if _, err2 := os.Stat(tailscale); err2 == nil { + return dir, nil + } + } + + track, err := versionTrack(version) + if err != nil { + return "", err + } + url := fmt.Sprintf("https://pkgs.tailscale.com/%s/tailscale_%s_%s.tgz", track, version, arch) + logf("downloading %s", url) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("fetching %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return "", fmt.Errorf("fetching %s: HTTP %s", url, resp.Status) + } + + if err := os.MkdirAll(dir, 0755); err != nil { + return "", err + } + + gzr, err := gzip.NewReader(resp.Body) + if err != nil { + return "", fmt.Errorf("gzip reader for %s: %w", url, err) + } + defer gzr.Close() + tr := tar.NewReader(gzr) + + wantBase := map[string]bool{ + "tailscale": true, + "tailscaled": true, + } + got := map[string]bool{} + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return "", fmt.Errorf("reading tarball %s: %w", url, err) + } + if hdr.Typeflag != tar.TypeReg { + continue + } + base := path.Base(hdr.Name) + if !wantBase[base] { + continue + } + if err := writeAtomic(filepath.Join(dir, base), tr, 0755); err != nil { + return "", fmt.Errorf("extracting %s from %s: %w", base, url, err) + } + got[base] = true + } + for b := range wantBase { + if !got[b] { + return "", fmt.Errorf("tarball %s missing %s", url, b) + } + } + logf("extracted %s and %s to %s", "tailscale", "tailscaled", dir) + return dir, nil +} + +// writeAtomic writes the contents of r to dst with the given permission +// bits, by writing to a sibling temp file and renaming on success. +func writeAtomic(dst string, r io.Reader, perm os.FileMode) error { + tmp := dst + ".tmp" + f, err := os.OpenFile(tmp, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, perm) + if err != nil { + return err + } + if _, err := io.Copy(f, r); err != nil { + f.Close() + os.Remove(tmp) + return err + } + if err := f.Close(); err != nil { + os.Remove(tmp) + return err + } + return os.Rename(tmp, dst) +} diff --git a/tstest/natlab/vmtest/version_test.go b/tstest/natlab/vmtest/version_test.go new file mode 100644 index 0000000000000..3750562905fe7 --- /dev/null +++ b/tstest/natlab/vmtest/version_test.go @@ -0,0 +1,97 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "context" + "flag" + "os" + "path/filepath" + "testing" +) + +var testDownloadVersion = flag.Bool("test-download-version", false, "in TestVersionDownload, actually hit pkgs.tailscale.com") + +func TestResolveTestVersionInvalid(t *testing.T) { + bad := []string{ + "", + "1.97", + "v1.97.255", + "1.97.255-pre", + "latest", + "unstabel", + } + for _, v := range bad { + got, err := resolveTestVersion(context.Background(), v) + if err == nil { + t.Errorf("resolveTestVersion(%q) = %q, want error", v, got) + } + } +} + +func TestVersionTrack(t *testing.T) { + cases := []struct { + v, want string + }{ + {"1.96.4", "stable"}, + {"1.97.255", "unstable"}, + {"1.98.0", "stable"}, + } + for _, c := range cases { + got, err := versionTrack(c.v) + if err != nil { + t.Errorf("versionTrack(%q): %v", c.v, err) + continue + } + if got != c.want { + t.Errorf("versionTrack(%q) = %q, want %q", c.v, got, c.want) + } + } +} + +// TestVersionDownload exercises the live network path (download + extract + +// cache). Skipped by default; set --test-download-version to run. +func TestVersionDownload(t *testing.T) { + if !*testDownloadVersion { + t.Skip("set --test-download-version to run") + } + cacheRoot := t.TempDir() + t.Setenv("VMTEST_BUILDS_CACHE_DIR", cacheRoot) + + ctx := context.Background() + const version = "1.96.4" // stable + dir, err := ensureVersionBinaries(ctx, version, "amd64", t.Logf) + if err != nil { + t.Fatal(err) + } + wantDir := filepath.Join(cacheRoot, version+"_amd64") + if dir != wantDir { + t.Errorf("dir = %q, want %q", dir, wantDir) + } + for _, name := range []string{"tailscale", "tailscaled"} { + fi, err := os.Stat(filepath.Join(dir, name)) + if err != nil { + t.Errorf("missing %s: %v", name, err) + continue + } + if fi.Size() < 1<<20 { + t.Errorf("%s suspiciously small: %d bytes", name, fi.Size()) + } + } + + // Re-fetch should be a fast no-op (cache hit). + if _, err := ensureVersionBinaries(ctx, version, "amd64", t.Logf); err != nil { + t.Fatalf("re-fetch: %v", err) + } + + // "unstable" resolution. + resolved, err := resolveTestVersion(ctx, "unstable") + if err != nil { + t.Fatalf("resolveTestVersion(unstable): %v", err) + } + t.Logf("unstable resolved to %q", resolved) + if resolved == "" || resolved == "unstable" { + t.Errorf("resolved = %q", resolved) + } +} diff --git a/tstest/natlab/vmtest/vmstatus.go b/tstest/natlab/vmtest/vmstatus.go new file mode 100644 index 0000000000000..38269a78061c0 --- /dev/null +++ b/tstest/natlab/vmtest/vmstatus.go @@ -0,0 +1,327 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "fmt" + "sync" + "time" +) + +// StepStatus is the state of a declared test step. +type StepStatus int + +const ( + StepPending StepStatus = iota // not yet started + StepRunning // Begin called + StepDone // End(nil) called + StepFailed // End(non-nil) called +) + +func (s StepStatus) String() string { + switch s { + case StepPending: + return "pending" + case StepRunning: + return "running" + case StepDone: + return "done" + case StepFailed: + return "failed" + } + return fmt.Sprintf("StepStatus(%d)", int(s)) +} + +// Icon returns a Unicode icon for the step status. +func (s StepStatus) Icon() string { + switch s { + case StepPending: + return "○" + case StepRunning: + return "◉" + case StepDone: + return "✓" + case StepFailed: + return "✗" + } + return "?" +} + +// Step is a declared stage of a test, created by [Env.AddStep]. +// The web UI shows all steps from the start, tracking their progress. +type Step struct { + mu sync.Mutex + name string + index int // 0-based position in Env.steps + env *Env + status StepStatus + err error + started time.Time + ended time.Time +} + +// Name returns the step's display name. +func (s *Step) Name() string { return s.name } + +// Index returns the step's 0-based position. +func (s *Step) Index() int { return s.index } + +// Status returns the current status. +func (s *Step) Status() StepStatus { + s.mu.Lock() + defer s.mu.Unlock() + return s.status +} + +// Err returns the error if the step failed, or nil. +func (s *Step) Err() error { + s.mu.Lock() + defer s.mu.Unlock() + return s.err +} + +// Elapsed returns how long the step has been running (if running) +// or how long it took (if done/failed). Returns 0 if pending. +func (s *Step) Elapsed() time.Duration { + s.mu.Lock() + defer s.mu.Unlock() + if s.started.IsZero() { + return 0 + } + if !s.ended.IsZero() { + return s.ended.Sub(s.started) + } + return time.Since(s.started) +} + +// Begin marks the step as running. Publishes an event to the web UI. +func (s *Step) Begin() { + s.mu.Lock() + if s.status != StepPending { + s.mu.Unlock() + panic(fmt.Sprintf("Step %q: Begin called in state %s", s.name, s.status)) + } + s.started = time.Now() + s.status = StepRunning + s.mu.Unlock() + s.env.publishStepChange(s) +} + +// End marks the step as done (err == nil) or failed (err != nil). +// It publishes a status change event to the web UI. +// It does not call t.Fatalf; callers should handle the error as appropriate +// (return it from errgroup, call t.Fatalf on the test goroutine, etc). +func (s *Step) End(err error) { + s.mu.Lock() + if s.status != StepRunning { + s.mu.Unlock() + panic(fmt.Sprintf("Step %q: End called in state %s", s.name, s.status)) + } + s.ended = time.Now() + if err != nil { + s.status = StepFailed + s.err = err + } else { + s.status = StepDone + } + s.mu.Unlock() + s.env.publishStepChange(s) +} + +// EventType identifies the kind of event published to the EventBus. +type EventType string + +const ( + EventStepChanged EventType = "step_changed" // a Step changed status + EventConsoleOutput EventType = "console_output" // serial console line + EventDHCPDiscover EventType = "dhcp_discover" // VM sent DHCP Discover + EventDHCPOffer EventType = "dhcp_offer" // server sent DHCP Offer + EventDHCPRequest EventType = "dhcp_request" // VM sent DHCP Request + EventDHCPAck EventType = "dhcp_ack" // server sent DHCP Ack + EventScreenshot EventType = "screenshot" // VM display screenshot (JPEG, base64) + EventTailscale EventType = "tailscale" // Tailscale status change + EventTestStatus EventType = "test_status" // test Running/Passed/Failed +) + +// TestStatus tracks whether the overall test is running, passed, or failed. +type TestStatus struct { + mu sync.Mutex + state string // "Running", "Passed", "Failed" + started time.Time + ended time.Time +} + +func newTestStatus() *TestStatus { + return &TestStatus{state: "Running", started: time.Now()} +} + +// State returns the current test state. +func (ts *TestStatus) State() string { + ts.mu.Lock() + defer ts.mu.Unlock() + return ts.state +} + +// Elapsed returns total test duration. +func (ts *TestStatus) Elapsed() time.Duration { + ts.mu.Lock() + defer ts.mu.Unlock() + if !ts.ended.IsZero() { + return ts.ended.Sub(ts.started) + } + return time.Since(ts.started) +} + +// StartUnixMilli returns the test start time as Unix milliseconds, +// for the client-side elapsed timer. +func (ts *TestStatus) StartUnixMilli() int64 { + ts.mu.Lock() + defer ts.mu.Unlock() + return ts.started.UnixMilli() +} + +// finish marks the test as passed or failed. +func (ts *TestStatus) finish(failed bool) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.ended = time.Now() + if failed { + ts.state = "Failed" + } else { + ts.state = "Passed" + } +} + +// VMEvent is a single event published to the [EventBus]. +type VMEvent struct { + Time time.Time + NodeName string // "" for global events + Type EventType + Message string // human-readable description + Detail string // e.g. IP address, node key + Step *Step // non-nil for EventStepChanged + NIC int // NIC index for DHCP events (0-based); -1 if not applicable +} + +// NICStatus is the DHCP state for one NIC on a node. +type NICStatus struct { + NetName string // human label like "192.168.1.0/24" or "10.0.0.0/24" + DHCP string // "waiting", "Discover sent", "Got 10.0.0.101", etc. +} + +// NodeStatus tracks the current DHCP and Tailscale state of a VM node +// for rendering on the web UI's initial page load. +type NodeStatus struct { + Name string + OS string + NICs []NICStatus // one per NIC; index matches NIC index + JoinsTailnet bool // whether this node runs Tailscale + Tailscale string // "--", "Up (100.64.0.1)", etc. + Console []string // recent console output lines (ring buffer) + Screenshot string // latest screenshot as data URI, or "" + ScreenshotPort int // Host.app screenshot server port, or 0 +} + +const maxConsoleLines = 200 + +const ( + eventBusHistorySize = 500 + subscriberChannelSize = 1000 +) + +// EventBus broadcasts VMEvents to subscribers and keeps a history for +// late joiners. It is safe for concurrent use. +type EventBus struct { + mu sync.Mutex + history []VMEvent + subscribers map[*subscriber]struct{} +} + +func newEventBus() *EventBus { + return &EventBus{ + subscribers: make(map[*subscriber]struct{}), + } +} + +// Publish sends an event to all subscribers and appends it to the history. +// Non-blocking: slow subscribers are skipped. +func (b *EventBus) Publish(ev VMEvent) { + if ev.Time.IsZero() { + ev.Time = time.Now() + } + b.mu.Lock() + defer b.mu.Unlock() + // Don't store screenshots in history — they're large and only the + // latest one matters (stored in NodeStatus.Screenshot instead). + if ev.Type != EventScreenshot { + b.history = append(b.history, ev) + } + if len(b.history) > eventBusHistorySize { + // Trim old events. + copy(b.history, b.history[len(b.history)-eventBusHistorySize:]) + b.history = b.history[:eventBusHistorySize] + } + for sub := range b.subscribers { + select { + case sub.ch <- ev: + default: + // Slow consumer, skip. + } + } +} + +// Subscribe returns a new subscriber that receives the event history +// followed by live events. +func (b *EventBus) Subscribe() *subscriber { + b.mu.Lock() + defer b.mu.Unlock() + sub := &subscriber{ + bus: b, + ch: make(chan VMEvent, subscriberChannelSize), + done: make(chan struct{}), + } + // Send history. + for _, ev := range b.history { + select { + case sub.ch <- ev: + default: + } + } + b.subscribers[sub] = struct{}{} + return sub +} + +func (b *EventBus) unsubscribe(sub *subscriber) { + b.mu.Lock() + defer b.mu.Unlock() + delete(b.subscribers, sub) +} + +// subscriber receives events from an [EventBus]. +type subscriber struct { + bus *EventBus + ch chan VMEvent + done chan struct{} + once sync.Once +} + +// Events returns the channel of events. Closed when Close is called. +func (s *subscriber) Events() <-chan VMEvent { + return s.ch +} + +// Close unsubscribes and closes the event channel. +func (s *subscriber) Close() { + s.once.Do(func() { + if s.bus != nil { + s.bus.unsubscribe(s) + } + close(s.done) + }) +} + +// Done returns a channel that's closed when Close is called. +func (s *subscriber) Done() <-chan struct{} { + return s.done +} diff --git a/tstest/natlab/vmtest/vmtest.go b/tstest/natlab/vmtest/vmtest.go new file mode 100644 index 0000000000000..9b029a11953c6 --- /dev/null +++ b/tstest/natlab/vmtest/vmtest.go @@ -0,0 +1,1671 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package vmtest provides a high-level framework for running integration tests +// across multiple QEMU virtual machines connected by natlab's vnet virtual +// network infrastructure. It supports mixed OS types (gokrazy, Ubuntu, Debian) +// and multi-NIC configurations for scenarios like subnet routing. +// +// Prerequisites: +// - qemu-system-x86_64 and KVM access (typically the "kvm" group; no root required) +// - A built gokrazy natlabapp image (auto-built on first run via "make natlab" in gokrazy/) +// +// Run tests with: +// +// go test ./tstest/natlab/vmtest/ --run-vm-tests -v +package vmtest + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "flag" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "net/url" + "os" + "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/google/gopacket/layers" + "go4.org/mem" + "golang.org/x/sync/errgroup" + "tailscale.com/client/local" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/tstest/natlab/vnet" + "tailscale.com/types/key" + "tailscale.com/util/mak" +) + +var ( + runVMTests = flag.Bool("run-vm-tests", false, "run tests that require VMs with KVM") + verboseVMDebug = flag.Bool("verbose-vm-debug", false, "enable verbose debug logging for VM tests") + testVersion = flag.String("test-version", "", `if non-empty, download tailscale & tailscaled at the given release version (e.g. "1.97.255", "unstable", or "stable") instead of building from the source tree`) +) + +// Env is a test environment that manages virtual networks and QEMU VMs. +// Create one with New, add networks and nodes, then call Start. +type Env struct { + t testing.TB + cfg vnet.Config + server *vnet.Server + nodes []*Node + tempDir string + + sockAddr string // shared Unix socket path for all QEMU netdevs + dgramSockAddr string // Unix dgram socket path for macOS VMs (tailmac) + binDir string // directory for compiled binaries + + // testVersion is the resolved Tailscale release version to use (empty if + // building from source). When non-empty, tailscale and tailscaled binaries + // are downloaded from pkgs.tailscale.com instead of compiled from the tree. + testVersion string + + // gokrazy-specific paths + gokrazyBase string // path to gokrazy base qcow2 image + gokrazyKernel string // path to gokrazy kernel + + // tailmac-specific paths (macOS VMs) + tailmacDir string // path to tailmac bin/ directory containing Host.app + macosSnapshot string // path to cached macOS VM snapshot directory + macosSnapshotOnce sync.Once + + qemuProcs []*exec.Cmd // launched QEMU processes + + sameTailnetUser bool // all nodes register as the same Tailnet user + allOnline bool // mark every peer as Online=true in MapResponses + + // Shared resource initialization (sync.Once for things multiple nodes share). + vnetOnce sync.Once + gokrazyOnce sync.Once + qemuSockOnce sync.Once + dgramSockOnce sync.Once + compileMu sync.Mutex + compileOnce map[string]*sync.Once // keyed by goos_goarch + imageOnce map[string]*sync.Once // keyed by OSImage.Name + + // Web UI support. + ctx context.Context // cancelled when test ends + eventBus *EventBus + testStatus *TestStatus + stepsMu sync.Mutex + stepsByKey map[string]*Step + steps []*Step + + nodeStatusMu sync.Mutex + nodeStatus map[string]*NodeStatus // keyed by node name +} + +// logVerbosef logs a message only when --verbose-vm-debug is set. +func (e *Env) logVerbosef(format string, args ...any) { + if *verboseVMDebug { + e.t.Helper() + e.t.Logf(format, args...) + } +} + +// vmPlatform defines how a VM type boots. Each OS image type (gokrazy, +// cloud, macOS) implements this interface. +type vmPlatform interface { + // planSteps registers steps with the web UI in a dry-run pass. + planSteps(e *Env, n *Node) + + // boot does everything needed to get this node running: ensure images, + // compile binaries, set up sockets, launch VM. Called concurrently. + boot(ctx context.Context, e *Env, n *Node) error +} + +// platform returns the vmPlatform for this node's OS type. +func (n *Node) platform() vmPlatform { + if n.os.IsMacOS { + return macPlatform{} + } + if n.os.IsGokrazy { + return gokrazyPlatform{} + } + return qemuCloudPlatform{} +} + +// AddStep declares an expected stage of the test. The web UI shows all steps +// from the start, tracking their progress. Call before or during the test. +// Returns a *Step whose Begin/End methods drive the progress display. +func (e *Env) AddStep(name string) *Step { + s := &Step{ + name: name, + index: len(e.steps), + env: e, + } + e.steps = append(e.steps, s) + return s +} + +// Step returns a step by key, creating it if it doesn't exist. +// Safe for concurrent use. Both planSteps (dry-run) and boot (real-run) +// call this to get the same Step object. +func (e *Env) Step(key string) *Step { + e.stepsMu.Lock() + defer e.stepsMu.Unlock() + if s, ok := e.stepsByKey[key]; ok { + return s + } + s := &Step{ + name: key, + index: len(e.steps), + env: e, + } + e.steps = append(e.steps, s) + if e.stepsByKey == nil { + e.stepsByKey = make(map[string]*Step) + } + e.stepsByKey[key] = s + return s +} + +// Steps returns all declared steps in order. +func (e *Env) Steps() []*Step { + return e.steps +} + +// publishStepChange publishes a step status change event. +func (e *Env) publishStepChange(s *Step) { + e.eventBus.Publish(VMEvent{ + Type: EventStepChanged, + Message: fmt.Sprintf("%s %s", s.Status().Icon(), s.name), + Step: s, + }) +} + +// initNodeStatus initializes the NodeStatus for all nodes. Called after +// AddNode but before Start so the web UI can render them. +func (e *Env) initNodeStatus() { + e.nodeStatusMu.Lock() + defer e.nodeStatusMu.Unlock() + for _, n := range e.nodes { + nics := make([]NICStatus, len(n.nets)) + for i := range n.nets { + nics[i] = NICStatus{ + NetName: e.nicLabel(n, i), + DHCP: "waiting", + } + } + e.nodeStatus[n.name] = &NodeStatus{ + Name: n.name, + OS: n.os.Name, + NICs: nics, + JoinsTailnet: n.joinTailnet, + Tailscale: "--", + } + } +} + +// nicLabel returns a short human-readable label for a node's i-th NIC. +// After Start(), we can use the assigned LAN IP. Before that, we use "NIC N". +func (e *Env) nicLabel(n *Node, i int) string { + if n.vnetNode != nil { + ip := n.vnetNode.LanIP(n.nets[i]) + if ip.IsValid() { + return ip.String() + } + } + return fmt.Sprintf("NIC %d", i) +} + +// getNodeStatus returns the current status for a node. +func (e *Env) getNodeStatus(name string) NodeStatus { + e.nodeStatusMu.Lock() + defer e.nodeStatusMu.Unlock() + ns := e.nodeStatus[name] + if ns == nil { + return NodeStatus{Name: name, Tailscale: "--"} + } + return *ns +} + +// setNodeDHCP updates the DHCP status for a specific NIC on a node. +func (e *Env) setNodeDHCP(name string, nicIdx int, status string) { + e.nodeStatusMu.Lock() + ns := e.nodeStatus[name] + if ns != nil && nicIdx < len(ns.NICs) { + ns.NICs[nicIdx].DHCP = status + } + e.nodeStatusMu.Unlock() +} + +// setNodeTailscale updates the Tailscale status for a node and publishes +// an event so the web UI updates via WebSocket. +func (e *Env) setNodeTailscale(name, status string) { + e.nodeStatusMu.Lock() + ns := e.nodeStatus[name] + if ns != nil { + ns.Tailscale = status + } + e.nodeStatusMu.Unlock() + e.eventBus.Publish(VMEvent{ + NodeName: name, + Type: EventTailscale, + Message: "Tailscale: " + status, + Detail: status, + }) +} + +// appendConsoleLine adds a line to a node's console buffer. +func (e *Env) appendConsoleLine(name, line string) { + e.nodeStatusMu.Lock() + ns := e.nodeStatus[name] + if ns != nil { + ns.Console = append(ns.Console, line) + if len(ns.Console) > maxConsoleLines { + ns.Console = ns.Console[len(ns.Console)-maxConsoleLines:] + } + } + e.nodeStatusMu.Unlock() +} + +// nicIndexForMAC returns the NIC index (0-based) for a given MAC on a node. +// Returns -1 if not found. +func (e *Env) nicIndexForMAC(name string, mac vnet.MAC) int { + for _, n := range e.nodes { + if n.name != name { + continue + } + for i := range n.nets { + if n.vnetNode.NICMac(i) == mac { + return i + } + } + } + return -1 +} + +// nodeNameByNum returns the node name for a given vnet node number. +func (e *Env) nodeNameByNum(num int) string { + for _, n := range e.nodes { + if n.num == num { + return n.name + } + } + return fmt.Sprintf("node%d", num) +} + +// New creates a new test environment. It skips the test if --run-vm-tests is +// not set. opts may contain [EnvOption] values returned by helpers like +// [SameTailnetUser]. +func New(t testing.TB, opts ...EnvOption) *Env { + if !*runVMTests { + t.Skip("skipping VM test; set --run-vm-tests to run") + } + + tempDir := t.TempDir() + e := &Env{ + t: t, + tempDir: tempDir, + binDir: filepath.Join(tempDir, "bin"), + eventBus: newEventBus(), + testStatus: newTestStatus(), + nodeStatus: make(map[string]*NodeStatus), + } + for _, o := range opts { + o.applyTo(e) + } + t.Cleanup(func() { + e.testStatus.finish(t.Failed()) + e.eventBus.Publish(VMEvent{ + Type: EventTestStatus, + Message: e.testStatus.State(), + Detail: formatDuration(e.testStatus.Elapsed()), + }) + }) + return e +} + +// EnvOption configures an [Env] in [New]. +type EnvOption interface { + applyTo(*Env) +} + +type envOptFunc func(*Env) + +func (f envOptFunc) applyTo(e *Env) { f(e) } + +// SameTailnetUser returns an [EnvOption] that makes every node register with +// the test control server as the same Tailnet user. This is needed for +// cross-node features that require a same-user relationship — Taildrop, for +// example. +func SameTailnetUser() EnvOption { + return envOptFunc(func(e *Env) { e.sameTailnetUser = true }) +} + +// AllOnline returns an [EnvOption] that makes the test control server mark +// every peer as Online=true in MapResponses (testcontrol.Server.AllOnline). +// Several disco-key handling fast paths in the controlclient and wgengine +// only fire when the peer is reported online; without this option those +// paths are silently skipped, which can mask bugs and slow down recovery +// from disco-key rotations. +func AllOnline() EnvOption { + return envOptFunc(func(e *Env) { e.allOnline = true }) +} + +// AddNetwork creates a new virtual network. Arguments follow the same pattern as +// vnet.Config.AddNetwork (string IPs, NAT types, NetworkService values). +func (e *Env) AddNetwork(opts ...any) *vnet.Network { + return e.cfg.AddNetwork(opts...) +} + +// Node represents a virtual machine in the test environment. +type Node struct { + name string + num int // assigned during AddNode + + os OSImage + nets []*vnet.Network + vnetNode *vnet.Node // primary vnet node (set during Start) + agent *vnet.NodeAgentClient + joinTailnet bool + noAgent bool // true to skip TTA agent setup (e.g. macOS VMs without TTA) + advertiseRoutes string + snatSubnetRoutes *bool // nil means default (true) + webServerPort int + sshPort int // host port for SSH debug access (cloud VMs only) +} + +// AddNode creates a new VM node. The name is used for identification and as the +// webserver greeting. Options can be *vnet.Network (for network attachment), +// NodeOption values, or vnet node options (like vnet.TailscaledEnv). +func (e *Env) AddNode(name string, opts ...any) *Node { + n := &Node{ + name: name, + os: Gokrazy, // default + joinTailnet: true, + } + e.nodes = append(e.nodes, n) + + // Separate network options from other options. + var vnetOpts []any + for _, o := range opts { + switch o := o.(type) { + case *vnet.Network: + n.nets = append(n.nets, o) + vnetOpts = append(vnetOpts, o) + case nodeOptOS: + n.os = OSImage(o) + case nodeOptNoTailscale: + n.joinTailnet = false + vnetOpts = append(vnetOpts, vnet.DontJoinTailnet) + case nodeOptNoAgent: + n.noAgent = true + case nodeOptAdvertiseRoutes: + n.advertiseRoutes = string(o) + case nodeOptSNATSubnetRoutes: + v := bool(o) + n.snatSubnetRoutes = &v + case nodeOptWebServer: + n.webServerPort = int(o) + default: + // Pass through to vnet (TailscaledEnv, NodeOption, MAC, etc.) + vnetOpts = append(vnetOpts, o) + } + } + + n.vnetNode = e.cfg.AddNode(vnetOpts...) + n.num = n.vnetNode.Num() + return n +} + +// LanIP returns the LAN IPv4 address of this node on the given network. +// This is only valid after Env.Start() has been called. +// Name returns the node's name as set in [Env.AddNode]. +func (n *Node) Name() string { + return n.name +} + +func (n *Node) LanIP(net *vnet.Network) netip.Addr { + return n.vnetNode.LanIP(net) +} + +// NodeOption types for configuring nodes. + +type nodeOptOS OSImage +type nodeOptNoTailscale struct{} +type nodeOptNoAgent struct{} +type nodeOptAdvertiseRoutes string +type nodeOptSNATSubnetRoutes bool +type nodeOptWebServer int + +// OS returns a NodeOption that sets the node's operating system image. +func OS(img OSImage) nodeOptOS { return nodeOptOS(img) } + +// DontJoinTailnet returns a NodeOption that prevents the node from running tailscale up. +func DontJoinTailnet() nodeOptNoTailscale { return nodeOptNoTailscale{} } + +// NoAgent returns a NodeOption that skips TTA agent setup. The node will not +// have a test agent, so agent-dependent operations (Status, ExecOnNode, etc.) +// won't work. Useful for VMs that just need to boot and respond to ICMP. +func NoAgent() nodeOptNoAgent { return nodeOptNoAgent{} } + +// AdvertiseRoutes returns a NodeOption that configures the node to advertise +// the given routes (comma-separated CIDRs) when joining the tailnet. +func AdvertiseRoutes(routes string) nodeOptAdvertiseRoutes { + return nodeOptAdvertiseRoutes(routes) +} + +// SNATSubnetRoutes returns a NodeOption that sets whether the node should +// source NAT traffic to advertised subnet routes. The default is true. +// Setting this to false preserves original source IPs, which is needed +// for site-to-site configurations. +func SNATSubnetRoutes(v bool) nodeOptSNATSubnetRoutes { return nodeOptSNATSubnetRoutes(v) } + +// WebServer returns a NodeOption that starts a webserver on the given port. +// The webserver responds with "Hello world I am from " on all requests. +func WebServer(port int) nodeOptWebServer { return nodeOptWebServer(port) } + +// Start initializes the virtual network, boots all VMs in parallel, and waits +// for all TTA agents to connect. It should be called after all AddNetwork/AddNode calls. +func (e *Env) Start() { + t := e.t + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + t.Cleanup(cancel) + e.ctx = ctx + + e.initNodeStatus() + e.maybeStartWebServer() + + if err := os.MkdirAll(e.binDir, 0755); err != nil { + t.Fatal(err) + } + + if *testVersion != "" { + v, err := resolveTestVersion(ctx, *testVersion) + if err != nil { + t.Fatalf("resolving --test-version=%q: %v", *testVersion, err) + } + e.testVersion = v + t.Logf("using Tailscale release version %s (from --test-version=%q)", v, *testVersion) + } + + for _, n := range e.nodes { + if n.os.IsMacOS && (runtime.GOOS != "darwin" || runtime.GOARCH != "arm64") { + t.Skip("macOS VM tests require macOS arm64 host") + } + } + + // Dry-run: let each platform register its steps with the web UI. + userSteps := e.steps + e.steps = nil + for _, n := range e.nodes { + n.platform().planSteps(e, n) + } + for _, n := range e.nodes { + if !n.noAgent { + e.Step("Wait for agent: " + n.name) + } + if n.joinTailnet { + e.Step("Tailscale up: " + n.name) + } + } + for _, s := range userSteps { + s.index = len(e.steps) + e.steps = append(e.steps, s) + } + + // Boot all nodes in parallel. Each platform handles its own + // dependencies (image prep, binary compilation, socket setup) + // via sync.Once, so independent work overlaps naturally. + var bootEg errgroup.Group + for _, n := range e.nodes { + bootEg.Go(func() error { + return n.platform().boot(ctx, e, n) + }) + } + if err := bootEg.Wait(); err != nil { + t.Fatalf("boot: %v", err) + } + + // Set up agent clients and wait for all agents to connect. + for _, n := range e.nodes { + if n.noAgent { + continue + } + e.initVnet() // ensure vnet is ready for agent clients + n.agent = e.server.NodeAgentClient(n.vnetNode) + n.vnetNode.SetClient(n.agent) + } + + var agentEg errgroup.Group + for _, n := range e.nodes { + if n.noAgent { + continue + } + agentEg.Go(func() error { + aStep := e.Step("Wait for agent: " + n.name) + aStep.Begin() + t.Logf("[%s] waiting for agent...", n.name) + if n.joinTailnet { + st, err := n.agent.Status(ctx) + if err != nil { + return fmt.Errorf("[%s] agent status: %w", n.name, err) + } + t.Logf("[%s] agent connected, backend state: %s", n.name, st.BackendState) + } else { + if err := e.waitForAgentConn(ctx, n); err != nil { + return fmt.Errorf("[%s] agent connect: %w", n.name, err) + } + t.Logf("[%s] agent connected (no tailscale)", n.name) + } + aStep.End(nil) + + if n.vnetNode.HostFirewall() { + if err := n.agent.EnableHostFirewall(ctx); err != nil { + return fmt.Errorf("[%s] enable firewall: %w", n.name, err) + } + } + + if n.joinTailnet { + tsStep := e.Step("Tailscale up: " + n.name) + tsStep.Begin() + if err := e.tailscaleUp(ctx, n); err != nil { + return fmt.Errorf("[%s] tailscale up: %w", n.name, err) + } + st2, err := n.agent.Status(ctx) + if err != nil { + return fmt.Errorf("[%s] status after up: %w", n.name, err) + } + if st2.BackendState != "Running" { + return fmt.Errorf("[%s] state = %q, want Running", n.name, st2.BackendState) + } + + // Apply any capabilities for the node to the map. + // SetNodeCapMap pushes an updated map response immediately, then wait + // until the node reports the capability in its status. + if cm := n.vnetNode.WantCapMap(); cm != nil { + e.server.ControlServer().SetNodeCapMap(st2.Self.PublicKey, cm) + if err := tstest.WaitFor(15*time.Second, func() error { + st, err := n.agent.Status(ctx) + if err != nil { + return err + } + if st.Self == nil { + return fmt.Errorf("self is nil") + } + for c := range cm { + if !st.Self.HasCap(c) { + return fmt.Errorf("cap %v not yet received", c) + } + } + return nil + }); err != nil { + return fmt.Errorf("[%s] waiting for capabilities: %w", n.name, err) + } + } + + ips := fmt.Sprintf("%v", st2.Self.TailscaleIPs) + e.setNodeTailscale(n.name, "Running "+ips) + t.Logf("[%s] up with %v", n.name, st2.Self.TailscaleIPs) + tsStep.End(nil) + } + + return nil + }) + } + if err := agentEg.Wait(); err != nil { + t.Fatal(err) + } + + // Start webservers. + for _, n := range e.nodes { + if n.webServerPort > 0 { + if err := e.startWebServer(ctx, n); err != nil { + t.Fatalf("startWebServer(%s): %v", n.name, err) + } + } + } +} + +// tailscaleUp runs "tailscale up" on the node via TTA. +func (e *Env) tailscaleUp(ctx context.Context, n *Node) error { + url := "http://unused/up?accept-routes=true" + if n.advertiseRoutes != "" { + url += "&advertise-routes=" + n.advertiseRoutes + } + if n.snatSubnetRoutes != nil { + if *n.snatSubnetRoutes { + url += "&snat-subnet-routes=true" + } else { + url += "&snat-subnet-routes=false" + } + } + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return err + } + res, err := n.agent.HTTPClient.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if res.StatusCode != 200 { + return fmt.Errorf("tailscale up: %s: %s", res.Status, body) + } + return nil +} + +// startWebServer tells TTA on the node to start a webserver. +func (e *Env) startWebServer(ctx context.Context, n *Node) error { + url := fmt.Sprintf("http://unused/start-webserver?port=%d&name=%s", n.webServerPort, n.name) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return err + } + res, err := n.agent.HTTPClient.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != 200 { + body, _ := io.ReadAll(res.Body) + return fmt.Errorf("start-webserver: %s: %s", res.Status, body) + } + e.t.Logf("[%s] webserver started on port %d", n.name, n.webServerPort) + return nil +} + +// SetExitNode sets the client node's exit node to use for internet traffic. +// If exitNode is nil, the client's exit node is cleared (i.e., turned off). +// Otherwise exitNode must be a tailnet node with an approved 0.0.0.0/0 (and +// ::/0) route, typically configured via [AdvertiseRoutes] and +// [Env.ApproveRoutes]. +func (e *Env) SetExitNode(client, exitNode *Node) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var ip netip.Addr + if exitNode != nil { + st, err := exitNode.agent.Status(ctx) + if err != nil { + e.t.Fatalf("SetExitNode: status for %s: %v", exitNode.name, err) + } + if len(st.Self.TailscaleIPs) == 0 { + e.t.Fatalf("SetExitNode: %s has no Tailscale IPs", exitNode.name) + } + ip = st.Self.TailscaleIPs[0] + } + + if _, err := client.agent.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + ExitNodeID: "", + ExitNodeIP: ip, + }, + ExitNodeIDSet: true, + ExitNodeIPSet: true, + }); err != nil { + e.t.Fatalf("SetExitNode(%s -> %v): %v", client.name, exitNode, err) + } + if exitNode == nil { + e.t.Logf("[%s] cleared exit node", client.name) + } else { + e.t.Logf("[%s] using exit node %s (%v)", client.name, exitNode.name, ip) + } +} + +// SetExitNodeIP sets the client's ExitNodeIP preference directly, by IP. +// This is the right helper for plain-WireGuard exit nodes (Mullvad-style) +// that aren't on the tailnet — pass an invalid netip.Addr{} to clear. +// For tailnet exit nodes whose Tailscale IP is discoverable via TTA, use +// [Env.SetExitNode] instead. +func (e *Env) SetExitNodeIP(client *Node, ip netip.Addr) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if _, err := client.agent.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + ExitNodeID: "", + ExitNodeIP: ip, + }, + ExitNodeIDSet: true, + ExitNodeIPSet: true, + }); err != nil { + e.t.Fatalf("SetExitNodeIP(%s, %v): %v", client.name, ip, err) + } + if !ip.IsValid() { + e.t.Logf("[%s] cleared exit node", client.name) + } else { + e.t.Logf("[%s] using exit-node IP %v", client.name, ip) + } +} + +// ControlServer returns the underlying test control server, for tests that +// need to inject custom peers, masquerade pairs, etc. The returned server's +// Node store is shared with the running tailnet, so changes take effect on +// the next netmap update sent to peers. +func (e *Env) ControlServer() *testcontrol.Server { + return e.server.ControlServer() +} + +// BringUpMullvadWGServer brings up a userspace WireGuard server on n, +// configured as a single-peer "Mullvad-style" exit-node target. The +// server runs inside n's TTA process on a Linux TUN named "wg0". +// +// gw is the WG interface address (e.g. 10.64.0.1/24). The server listens +// on listenPort, accepts only the single peer whose public key is peerPub +// at peerAllowedIP, and MASQUERADEs egress traffic from masqSrc so that +// decrypted packets from the peer egress with n's WAN IP. +// +// It returns the freshly generated public key of the WG server, which +// the caller must pin as the peer key on the [tailcfg.Node] it injects +// into the netmap to advertise this server as a plain-WireGuard exit +// node. It fatals the test on error. +func (e *Env) BringUpMullvadWGServer(n *Node, gw netip.Prefix, listenPort uint16, peerPub key.NodePublic, peerAllowedIP, masqSrc netip.Prefix) key.NodePublic { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + peerPubRaw := peerPub.Raw32() + v := url.Values{ + "addr": {gw.String()}, + "listen-port": {strconv.Itoa(int(listenPort))}, + "peer-pub-b64": {base64.StdEncoding.EncodeToString(peerPubRaw[:])}, + "peer-allowed-ip": {peerAllowedIP.String()}, + "masq-src": {masqSrc.String()}, + } + reqURL := "http://unused/wg-server-up?" + v.Encode() + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + e.t.Fatalf("BringUpMullvadWGServer: %v", err) + } + res, err := n.agent.HTTPClient.Do(req) + if err != nil { + e.t.Fatalf("BringUpMullvadWGServer(%s): %v", n.name, err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if res.StatusCode != 200 { + e.t.Fatalf("BringUpMullvadWGServer(%s): %s: %s", n.name, res.Status, body) + } + var pubB64 string + for _, line := range strings.Split(string(body), "\n") { + if s, ok := strings.CutPrefix(strings.TrimSpace(line), "PUBKEY="); ok { + pubB64 = s + break + } + } + if pubB64 == "" { + e.t.Fatalf("BringUpMullvadWGServer(%s): no PUBKEY in response: %q", n.name, body) + } + pubRaw, err := base64.StdEncoding.DecodeString(pubB64) + if err != nil || len(pubRaw) != 32 { + e.t.Fatalf("BringUpMullvadWGServer(%s): bad PUBKEY %q: %v", n.name, pubB64, err) + } + return key.NodePublicFromRaw32(mem.B(pubRaw)) +} + +// Status returns the tailscale status of the given node, fetched from its +// TTA agent. It fatals the test on error. +func (e *Env) Status(n *Node) *ipnstate.Status { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + st, err := n.agent.Status(ctx) + if err != nil { + e.t.Fatalf("Status(%s): %v", n.name, err) + } + return st +} + +// SetAcceptRoutes toggles the node's RouteAll preference (the +// --accept-routes flag), controlling whether it installs subnet routes +// advertised by peers. +func (e *Env) SetAcceptRoutes(n *Node, on bool) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if _, err := n.agent.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{RouteAll: on}, + RouteAllSet: true, + }); err != nil { + e.t.Fatalf("SetAcceptRoutes(%s, %v): %v", n.name, on, err) + } + e.t.Logf("[%s] accept-routes=%v", n.name, on) +} + +// ApproveRoutes tells the test control server to approve subnet routes +// for the given node. The routes should be CIDR strings. +func (e *Env) ApproveRoutes(n *Node, routes ...string) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Get the node's public key from its status. + st, err := n.agent.Status(ctx) + if err != nil { + e.t.Fatalf("ApproveRoutes: status for %s: %v", n.name, err) + } + nodeKey := st.Self.PublicKey + + var prefixes []netip.Prefix + for _, r := range routes { + p, err := netip.ParsePrefix(r) + if err != nil { + e.t.Fatalf("ApproveRoutes: bad route %q: %v", r, err) + } + prefixes = append(prefixes, p) + } + + // Enable --accept-routes on all other tailscale nodes BEFORE setting the + // routes on the control server. This way, when the map update arrives with + // the new peer routes, peers will immediately install them. + for _, other := range e.nodes { + if other == n || !other.joinTailnet { + continue + } + if _, err := other.agent.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{RouteAll: true}, + RouteAllSet: true, + }); err != nil { + e.t.Fatalf("ApproveRoutes: set accept-routes on %s: %v", other.name, err) + } + } + + // Approve the routes on the control server. SetSubnetRoutes notifies all + // peers via updatePeerChanged, so they'll re-fetch their MapResponse. + e.server.ControlServer().SetSubnetRoutes(nodeKey, prefixes) + + // Wait for each peer to see the routes. + for _, r := range routes { + for _, other := range e.nodes { + if other == n || !other.joinTailnet { + continue + } + if !e.waitForPeerRoute(other, r, 15*time.Second) { + e.DumpStatus(other) + e.t.Fatalf("ApproveRoutes: %s never saw route %s", other.name, r) + } + } + } + e.t.Logf("approved routes %v on %s", routes, n.name) + + // Ping the advertiser from each peer to establish WireGuard tunnels. + for _, other := range e.nodes { + if other == n || !other.joinTailnet { + continue + } + e.ping(other, n) + } +} + +// ping does a disco ping from one node to another's Tailscale IP, retrying +// for up to 30 seconds, fataling on failure. It is used internally to wake +// up magicsock peer state before a test runs; tests that want to assert +// connectivity should use [Env.Ping] with the appropriate ping type and +// timeout. +func (e *Env) ping(from, to *Node) { + e.t.Helper() + if err := e.Ping(from, to, tailcfg.PingDisco, 30*time.Second); err != nil { + e.t.Fatal(err) + } +} + +// Ping pings from one node to another's Tailscale IP using the given ping +// type, retrying until it succeeds or timeout expires. It returns the error +// from the last attempt if the timeout expires. Unlike the internal ping +// helper, it does not fatal the test on failure; callers can check the error +// to assert on timing. +// +// [tailcfg.PingTSMP] actually flows packets across the WireGuard tunnel and is +// the right choice for asserting end-to-end connectivity. +// [tailcfg.PingDisco] only exchanges disco messages between magicsock layers +// and is useful for warming up peer state without requiring a working tunnel. +func (e *Env) Ping(from, to *Node, ptype tailcfg.PingType, timeout time.Duration) error { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + toSt, err := to.agent.Status(ctx) + if err != nil { + return fmt.Errorf("ping: can't get %s status: %w", to.name, err) + } + if len(toSt.Self.TailscaleIPs) == 0 { + return fmt.Errorf("ping: %s has no Tailscale IPs", to.name) + } + targetIP := toSt.Self.TailscaleIPs[0] + + var lastErr error + for { + // Per-attempt timeout: cap at 3s but never exceed the remaining budget. + attemptTimeout := 3 * time.Second + if d := time.Until(deadline(ctx)); d < attemptTimeout { + attemptTimeout = d + } + if attemptTimeout <= 0 { + break + } + pingCtx, pingCancel := context.WithTimeout(ctx, attemptTimeout) + pr, err := from.agent.PingWithOpts(pingCtx, targetIP, ptype, local.PingOpts{}) + pingCancel() + if err == nil && pr.Err == "" { + e.logVerbosef("ping(%s): %s -> %s OK", ptype, from.name, targetIP) + return nil + } + switch { + case err != nil: + lastErr = err + case pr.Err != "": + lastErr = fmt.Errorf("%s", pr.Err) + } + if ctx.Err() != nil { + break + } + time.Sleep(500 * time.Millisecond) + } + if lastErr == nil { + lastErr = ctx.Err() + } + return fmt.Errorf("ping(%s): %s -> %s (%s) timed out after %v: %w", ptype, from.name, to.name, targetIP, timeout, lastErr) +} + +// deadline returns ctx's deadline, or a zero Time if it has none. +func deadline(ctx context.Context) time.Time { + d, _ := ctx.Deadline() + return d +} + +// PeerDiscoKey returns n's view of the given peer's disco key. It returns a +// non-nil error if the LocalAPI request fails (e.g. tailscaled briefly +// unavailable during a restart). It returns (zero, false, nil) if n is +// reachable but has no record of the given peer in its current netmap. +// +// PeerDiscoKey is suitable for use inside a [tstest.WaitFor] poll loop: it +// does not fatal the test on transient errors. +// +// The disco key is fetched from the debug-only "peer-disco-keys" LocalAPI +// action ([ipnlocal.LocalBackend.DebugPeerDiscoKeys]) rather than via +// [ipnstate.Status], to keep the production PeerStatus struct free of disco +// keys (and free of non-comparable fields like [key.DiscoPublic] that break +// reflect-based test helpers). +func (e *Env) PeerDiscoKey(n *Node, peer key.NodePublic) (key.DiscoPublic, bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + got, err := n.agent.DebugResultJSON(ctx, "peer-disco-keys") + if err != nil { + return key.DiscoPublic{}, false, err + } + // DebugResultJSON returns the result as a generic any (the body is + // re-decoded into any), so the map comes back keyed by string text- + // encoded node keys. Re-marshal+unmarshal into a typed map for cleaner + // lookup. (Roundtripping through JSON is fine for a test helper.) + raw, err := json.Marshal(got) + if err != nil { + return key.DiscoPublic{}, false, fmt.Errorf("re-marshal: %w", err) + } + var m map[key.NodePublic]key.DiscoPublic + if err := json.Unmarshal(raw, &m); err != nil { + return key.DiscoPublic{}, false, fmt.Errorf("unmarshal peer-disco-keys: %w", err) + } + d, ok := m[peer] + return d, ok, nil +} + +// RotateDiscoKey asks tailscaled on n to rotate its discovery (magicsock) key +// in place via the LocalAPI debug action. The node key, control connection, +// and other tailscaled state are unaffected. It fatals the test on error. +func (e *Env) RotateDiscoKey(n *Node) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := n.agent.DebugAction(ctx, "rotate-disco-key"); err != nil { + e.t.Fatalf("RotateDiscoKey(%s): %v", n.name, err) + } +} + +// RestartTailscaled signals tailscaled on n to die so that its supervisor +// (gokrazy) restarts it. It then waits for tailscaled to come back to the +// "Running" backend state. It fatals the test on error. +// +// Restarting tailscaled is currently only supported on gokrazy nodes. +func (e *Env) RestartTailscaled(n *Node) { + e.t.Helper() + if !n.os.IsGokrazy { + e.t.Fatalf("RestartTailscaled(%s): only supported on gokrazy nodes (have %q)", n.name, n.os.Name) + } + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/restart-tailscaled", nil) + if err != nil { + e.t.Fatalf("RestartTailscaled(%s): %v", n.name, err) + } + res, err := n.agent.HTTPClient.Do(req) + if err != nil { + e.t.Fatalf("RestartTailscaled(%s): %v", n.name, err) + } + body, _ := io.ReadAll(res.Body) + res.Body.Close() + if res.StatusCode != 200 { + e.t.Fatalf("RestartTailscaled(%s): %s: %s", n.name, res.Status, body) + } + e.t.Logf("[%s] %s", n.name, strings.TrimSpace(string(body))) + + // Wait for tailscaled to come back. Status calls will fail while the unix + // socket is gone, then return Starting/NeedsLogin briefly before settling + // on Running. + if err := tstest.WaitFor(45*time.Second, func() error { + st, err := n.agent.Status(ctx) + if err != nil { + return err + } + if st.BackendState != "Running" { + return fmt.Errorf("backend state = %q", st.BackendState) + } + return nil + }); err != nil { + e.t.Fatalf("RestartTailscaled(%s): waiting for Running: %v", n.name, err) + } +} + +// AddRoute adds a kernel static route on the given node, pointing prefix at +// via. It uses TTA's /add-route handler, so it works on any node where TTA +// is running (which is all of them — DontJoinTailnet only skips +// `tailscale up`; the agent runs regardless). Currently Linux-only in TTA. +// +// It fatals the test on error. +func (e *Env) AddRoute(n *Node, prefix, via string) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + reqURL := fmt.Sprintf("http://unused/add-route?prefix=%s&via=%s", prefix, via) + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + e.t.Fatalf("AddRoute: %v", err) + } + resp, err := n.agent.HTTPClient.Do(req) + if err != nil { + e.t.Fatalf("AddRoute(%s, %s → %s): %v", n.name, prefix, via, err) + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != 200 { + e.t.Fatalf("AddRoute(%s, %s → %s): %s: %s", n.name, prefix, via, resp.Status, body) + } +} + +// SSHExec runs a command on a cloud VM via its debug SSH NIC. +// Only works for cloud VMs that have the debug NIC and SSH key configured. +// Returns stdout and any error. +func (e *Env) SSHExec(n *Node, cmd string) (string, error) { + if n.sshPort == 0 { + return "", fmt.Errorf("node %s has no SSH debug port", n.name) + } + sshCmd := exec.Command("ssh", + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=5", + "-i", "/tmp/vmtest_key", + "-p", fmt.Sprintf("%d", n.sshPort), + "root@127.0.0.1", + cmd) + out, err := sshCmd.CombinedOutput() + return string(out), err +} + +// DumpStatus logs the tailscale status of a node, including its peers and their +// AllowedIPs. Useful for debugging routing issues. +func (e *Env) DumpStatus(n *Node) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + st, err := n.agent.Status(ctx) + if err != nil { + e.t.Logf("[%s] DumpStatus error: %v", n.name, err) + return + } + var selfAllowed []string + if st.Self.AllowedIPs != nil { + for i := range st.Self.AllowedIPs.Len() { + selfAllowed = append(selfAllowed, st.Self.AllowedIPs.At(i).String()) + } + } + var selfPrimary []string + if st.Self.PrimaryRoutes != nil { + for i := range st.Self.PrimaryRoutes.Len() { + selfPrimary = append(selfPrimary, st.Self.PrimaryRoutes.At(i).String()) + } + } + e.t.Logf("[%s] self: %v, backend=%s, AllowedIPs=%v, PrimaryRoutes=%v", n.name, st.Self.TailscaleIPs, st.BackendState, selfAllowed, selfPrimary) + for _, peer := range st.Peer { + var aips []string + if peer.AllowedIPs != nil { + for i := range peer.AllowedIPs.Len() { + aips = append(aips, peer.AllowedIPs.At(i).String()) + } + } + e.t.Logf("[%s] peer %s (%s): AllowedIPs=%v, Online=%v, Relay=%q, CurAddr=%q", + n.name, peer.HostName, peer.TailscaleIPs, + aips, peer.Online, peer.Relay, peer.CurAddr) + } +} + +// waitForPeerRoute polls the node's status until it sees the given route prefix +// in a peer's AllowedIPs, or until timeout. Returns true if found. +func (e *Env) waitForPeerRoute(n *Node, prefix string, timeout time.Duration) bool { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for { + st, err := n.agent.Status(ctx) + if err != nil { + return false + } + for _, peer := range st.Peer { + if peer.AllowedIPs != nil { + for i := range peer.AllowedIPs.Len() { + if peer.AllowedIPs.At(i).String() == prefix { + return true + } + } + } + } + if ctx.Err() != nil { + return false + } + time.Sleep(time.Second) + } +} + +// HTTPGet makes an HTTP GET request from the given node to the specified URL. +// The request is proxied through TTA's /http-get handler. +func (e *Env) HTTPGet(from *Node, targetURL string) string { + for attempt := range 3 { + ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second) + reqURL := "http://unused/http-get?url=" + targetURL + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + cancel() + e.t.Fatalf("HTTPGet: %v", err) + } + res, err := from.agent.HTTPClient.Do(req) + cancel() + if err != nil { + e.logVerbosef("HTTPGet attempt %d from %s: %v", attempt+1, from.name, err) + continue + } + body, _ := io.ReadAll(res.Body) + res.Body.Close() + if res.StatusCode == http.StatusBadGateway || res.StatusCode == http.StatusServiceUnavailable { + e.t.Logf("HTTPGet attempt %d from %s: status %d, body: %s", attempt+1, from.name, res.StatusCode, string(body)) + time.Sleep(2 * time.Second) + continue + } + return string(body) + } + e.t.Fatalf("HTTPGet from %s to %s: all attempts failed", from.name, targetURL) + return "" +} + +// setNodeScreenshot stores the latest screenshot data URI for a node. +func (e *Env) setNodeScreenshot(name, dataURI string) { + e.nodeStatusMu.Lock() + if ns := e.nodeStatus[name]; ns != nil { + ns.Screenshot = dataURI + } + e.nodeStatusMu.Unlock() +} + +// setNodeScreenshotPort stores the Host.app screenshot server port for a node. +func (e *Env) setNodeScreenshotPort(name string, port int) { + e.nodeStatusMu.Lock() + if ns := e.nodeStatus[name]; ns != nil { + ns.ScreenshotPort = port + } + e.nodeStatusMu.Unlock() +} + +// nodeScreenshotPort returns the Host.app screenshot server port for a node, or 0. +func (e *Env) nodeScreenshotPort(name string) int { + e.nodeStatusMu.Lock() + defer e.nodeStatusMu.Unlock() + if ns := e.nodeStatus[name]; ns != nil { + return ns.ScreenshotPort + } + return 0 +} + +// initVnet creates the vnet server. Called once via sync.Once. +func (e *Env) initVnet() { + e.vnetOnce.Do(func() { + var err error + e.server, err = vnet.New(&e.cfg) + if err != nil { + e.t.Fatalf("vnet.New: %v", err) + } + e.t.Cleanup(func() { e.server.Close() }) + + e.server.SetDHCPCallback(func(mac vnet.MAC, nodeNum int, msgType layers.DHCPMsgType, ip netip.Addr) { + name := e.nodeNameByNum(nodeNum) + nicIdx := e.nicIndexForMAC(name, mac) + ipStr := ip.String() + switch msgType { + case layers.DHCPMsgTypeDiscover: + e.setNodeDHCP(name, nicIdx, "Discover sent") + e.eventBus.Publish(VMEvent{NodeName: name, Type: EventDHCPDiscover, Message: "DHCP Discover sent", NIC: nicIdx}) + case layers.DHCPMsgTypeOffer: + e.setNodeDHCP(name, nicIdx, "Offered "+ipStr) + e.eventBus.Publish(VMEvent{NodeName: name, Type: EventDHCPOffer, Message: "DHCP Offer received", Detail: ipStr, NIC: nicIdx}) + case layers.DHCPMsgTypeRequest: + e.setNodeDHCP(name, nicIdx, "Requesting "+ipStr) + e.eventBus.Publish(VMEvent{NodeName: name, Type: EventDHCPRequest, Message: "DHCP Request sent", Detail: ipStr, NIC: nicIdx}) + case layers.DHCPMsgTypeAck: + e.setNodeDHCP(name, nicIdx, "Got "+ipStr) + e.eventBus.Publish(VMEvent{NodeName: name, Type: EventDHCPAck, Message: "DHCP Ack: got " + ipStr, Detail: ipStr, NIC: nicIdx}) + } + }) + + if e.sameTailnetUser { + e.server.ControlServer().AllNodesSameUser = true + } + if e.allOnline { + e.server.ControlServer().AllOnline = true + } + }) +} + +// ensureQEMUSocket creates the Unix stream socket for QEMU VMs. Called once. +func (e *Env) ensureQEMUSocket() { + e.qemuSockOnce.Do(func() { + e.initVnet() + e.sockAddr = filepath.Join(e.tempDir, "vnet.sock") + srv, err := net.Listen("unix", e.sockAddr) + if err != nil { + e.t.Fatalf("listen unix: %v", err) + } + e.t.Cleanup(func() { srv.Close() }) + go func() { + for { + c, err := srv.Accept() + if err != nil { + return + } + go e.server.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU) + } + }() + }) +} + +// ensureDgramSocket creates the Unix dgram socket for macOS VMs. Called once. +func (e *Env) ensureDgramSocket() { + e.dgramSockOnce.Do(func() { + e.initVnet() + e.dgramSockAddr = fmt.Sprintf("/tmp/vmtest-dgram-%d.sock", os.Getpid()) + e.t.Cleanup(func() { os.Remove(e.dgramSockAddr) }) + dgramAddr, err := net.ResolveUnixAddr("unixgram", e.dgramSockAddr) + if err != nil { + e.t.Fatalf("resolve dgram addr: %v", err) + } + uc, err := net.ListenUnixgram("unixgram", dgramAddr) + if err != nil { + e.t.Fatalf("listen unixgram: %v", err) + } + e.t.Cleanup(func() { uc.Close() }) + go e.server.ServeUnixConn(uc, vnet.ProtocolUnixDGRAM) + }) +} + +// ensureCompiled compiles binaries for the given platform and registers them +// with the vnet file server. Safe for concurrent use; only compiles once per platform. +func (e *Env) ensureCompiled(ctx context.Context, goos, goarch string) { + key := goos + "_" + goarch + + e.compileMu.Lock() + once, ok := e.compileOnce[key] + if !ok { + once = new(sync.Once) + mak.Set(&e.compileOnce, key, once) + } + e.compileMu.Unlock() + + once.Do(func() { + step := e.Step(fmt.Sprintf("Compile %s_%s binaries", goos, goarch)) + step.Begin() + if err := e.compileBinariesForOS(ctx, goos, goarch); err != nil { + step.End(err) + e.t.Fatalf("compileBinariesForOS(%s, %s): %v", goos, goarch, err) + } + step.End(nil) + e.registerBinaries(goos, goarch) + }) +} + +// ensureImage prepares the cloud image for os and returns any error from the +// preparation. Safe for concurrent use; only prepares once per OS name. +func (e *Env) ensureImage(ctx context.Context, os OSImage) error { + e.compileMu.Lock() + once, ok := e.imageOnce[os.Name] + if !ok { + once = new(sync.Once) + mak.Set(&e.imageOnce, os.Name, once) + } + e.compileMu.Unlock() + + var err error + once.Do(func() { + step := e.Step(fmt.Sprintf("Prepare %s image", os.Name)) + step.Begin() + err = ensureImage(ctx, os) + step.End(err) + }) + return err +} + +// registerBinaries registers compiled binaries with the vnet file server. +// Safe for concurrent use. +func (e *Env) registerBinaries(goos, goarch string) { + e.initVnet() + dir := goos + "_" + goarch + for _, name := range []string{"tta", "tailscale", "tailscaled"} { + data, err := os.ReadFile(filepath.Join(e.binDir, dir, name)) + if err != nil { + e.t.Fatalf("reading compiled %s/%s: %v", dir, name, err) + } + e.server.RegisterFile(dir+"/"+name, data) + } +} + +// waitForAgentConn waits for a TTA agent to connect by issuing a simple +// HTTP GET to the root endpoint, without requiring tailscaled. +func (e *Env) waitForAgentConn(ctx context.Context, n *Node) error { + for { + reqCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + req, err := http.NewRequestWithContext(reqCtx, "GET", "http://unused/", nil) + if err != nil { + cancel() + return err + } + res, err := n.agent.HTTPClient.Do(req) + cancel() + if err == nil { + res.Body.Close() + return nil + } + if ctx.Err() != nil { + return ctx.Err() + } + time.Sleep(500 * time.Millisecond) + } +} + +// Agent returns the node's TTA agent client, or nil if NoAgent is set. +func (n *Node) Agent() *vnet.NodeAgentClient { + return n.agent +} + +// LANPing pings a LAN IP from the given node using TTA's /ping endpoint. +// It retries for up to 2 minutes, which is enough for a macOS VM to boot +// and acquire a DHCP lease. +func (e *Env) LANPing(from *Node, targetIP netip.Addr) { + if from.agent == nil { + e.t.Fatalf("LANPing: node %s has no agent (NoAgent set?)", from.name) + } + e.t.Logf("LANPing: %s -> %s", from.name, targetIP) + deadline := time.Now().Add(2 * time.Minute) + for attempt := 0; time.Now().Before(deadline); attempt++ { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + reqURL := fmt.Sprintf("http://unused/ping?host=%s", targetIP) + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + cancel() + e.t.Fatalf("LANPing: %v", err) + } + res, err := from.agent.HTTPClient.Do(req) + cancel() + if err != nil { + if attempt%10 == 0 { + e.t.Logf("LANPing attempt %d: %v", attempt+1, err) + } + time.Sleep(2 * time.Second) + continue + } + body, _ := io.ReadAll(res.Body) + res.Body.Close() + if res.StatusCode == 200 { + e.t.Logf("LANPing: %s -> %s succeeded on attempt %d", from.name, targetIP, attempt+1) + return + } + if attempt%10 == 0 { + e.t.Logf("LANPing attempt %d: status %d, body: %s", attempt+1, res.StatusCode, string(body)) + } + time.Sleep(2 * time.Second) + } + e.t.Fatalf("LANPing: %s -> %s timed out after 2 minutes", from.name, targetIP) +} + +// SendTaildropFile sends a file via Taildrop from one node to another. +// The to node must be on the tailnet. It fatals on error. +func (e *Env) SendTaildropFile(from, to *Node, name string, content []byte) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + st, err := to.agent.Status(ctx) + if err != nil { + e.t.Fatalf("SendTaildropFile: status for %s: %v", to.name, err) + } + if len(st.Self.TailscaleIPs) == 0 { + e.t.Fatalf("SendTaildropFile: %s has no Tailscale IPs", to.name) + } + target := st.Self.TailscaleIPs[0].String() + + reqURL := fmt.Sprintf("http://unused/taildrop-send?to=%s&name=%s", target, name) + req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewReader(content)) + if err != nil { + e.t.Fatalf("SendTaildropFile: %v", err) + } + res, err := from.agent.HTTPClient.Do(req) + if err != nil { + e.t.Fatalf("SendTaildropFile(%s -> %s): %v", from.name, to.name, err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if res.StatusCode != 200 { + e.t.Fatalf("SendTaildropFile(%s -> %s): %s: %s", from.name, to.name, res.Status, body) + } + if msg := strings.TrimSpace(string(body)); msg != "" { + e.t.Logf("[%s] %s", from.name, msg) + } + e.t.Logf("[%s] sent Taildrop %q (%d bytes) to %s", from.name, name, len(content), to.name) +} + +// RecvTaildropFile waits for an incoming Taildrop file on the node and +// returns the filename and contents. The provided context bounds the wait; +// in addition, RecvTaildropFile imposes its own 90s upper bound. It fatals +// on error or timeout. +func (e *Env) RecvTaildropFile(ctx context.Context, n *Node) (name string, content []byte) { + e.t.Helper() + ctx, cancel := context.WithTimeout(ctx, 90*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/taildrop-recv", nil) + if err != nil { + e.t.Fatalf("RecvTaildropFile: %v", err) + } + res, err := n.agent.HTTPClient.Do(req) + if err != nil { + e.t.Fatalf("RecvTaildropFile(%s): %v", n.name, err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if res.StatusCode != 200 { + e.t.Fatalf("RecvTaildropFile(%s): %s: %s", n.name, res.Status, body) + } + name = res.Header.Get("Taildrop-Filename") + e.t.Logf("[%s] received Taildrop %q (%d bytes)", n.name, name, len(body)) + return name, body +} + +var buildGokrazy sync.Once + +// ensureGokrazy builds the gokrazy base image (once per test process) and +// locates the kernel. The build is fast (~4s) so we always rebuild to ensure +// the baked-in binaries (tta, tailscale, tailscaled) match the current source. +func (e *Env) ensureGokrazy(ctx context.Context) error { + if e.gokrazyBase != "" { + return nil // already found + } + + modRoot, err := findModRoot() + if err != nil { + return err + } + + var buildErr error + buildGokrazy.Do(func() { + e.t.Logf("building gokrazy natlab image...") + cmd := exec.CommandContext(ctx, "make", "natlab") + cmd.Dir = filepath.Join(modRoot, "gokrazy") + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + if err := cmd.Run(); err != nil { + buildErr = fmt.Errorf("make natlab: %w", err) + } + }) + if buildErr != nil { + return buildErr + } + + e.gokrazyBase = filepath.Join(modRoot, "gokrazy/natlabapp.qcow2") + + kernel, err := findKernelPath(filepath.Join(modRoot, "go.mod")) + if err != nil { + return fmt.Errorf("finding kernel: %w", err) + } + e.gokrazyKernel = kernel + return nil +} + +// compileBinariesForOS prepares the tta, tailscale, and tailscaled binaries +// for the given GOOS/GOARCH and places them in e.binDir/_/. +// +// tta is always built from the local source tree (the test agent must match +// the test framework). When --test-version is set, tailscale and tailscaled +// are taken from the downloaded release tarball instead of being compiled +// from source. +func (e *Env) compileBinariesForOS(ctx context.Context, goos, goarch string) error { + modRoot, err := findModRoot() + if err != nil { + return err + } + + dir := goos + "_" + goarch + outDir := filepath.Join(e.binDir, dir) + if err := os.MkdirAll(outDir, 0755); err != nil { + return err + } + + // Use downloaded release binaries only on Linux: pkgs.tailscale.com only + // publishes Linux tarballs, so other GOOS values still build from source. + useDownloaded := e.testVersion != "" && goos == "linux" + + type binary struct{ name, pkg string } + buildBins := []binary{{"tta", "./cmd/tta"}} + if !useDownloaded { + buildBins = append(buildBins, + binary{"tailscale", "./cmd/tailscale"}, + binary{"tailscaled", "./cmd/tailscaled"}) + } + + var eg errgroup.Group + for _, bin := range buildBins { + eg.Go(func() error { + outPath := filepath.Join(outDir, bin.name) + e.t.Logf("compiling %s/%s...", dir, bin.name) + cmd := exec.CommandContext(ctx, "go", "build", "-o", outPath, bin.pkg) + cmd.Dir = modRoot + cmd.Env = append(os.Environ(), "GOOS="+goos, "GOARCH="+goarch, "CGO_ENABLED=0") + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("building %s/%s: %v\n%s", dir, bin.name, err, out) + } + e.t.Logf("compiled %s/%s", dir, bin.name) + return nil + }) + } + + if useDownloaded { + eg.Go(func() error { + srcDir, err := ensureVersionBinaries(ctx, e.testVersion, goarch, e.t.Logf) + if err != nil { + return err + } + for _, name := range []string{"tailscale", "tailscaled"} { + if err := copyFile(filepath.Join(srcDir, name), filepath.Join(outDir, name), 0755); err != nil { + return fmt.Errorf("staging %s/%s: %w", dir, name, err) + } + } + e.t.Logf("staged version %s tailscale & tailscaled for %s", e.testVersion, dir) + return nil + }) + } + + return eg.Wait() +} + +// copyFile copies src to dst with the given permission bits. +func copyFile(src, dst string, perm os.FileMode) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + return writeAtomic(dst, in, perm) +} + +// findModRoot returns the root of the Go module (where go.mod is). +func findModRoot() (string, error) { + out, err := exec.Command("go", "env", "GOMOD").CombinedOutput() + if err != nil { + return "", fmt.Errorf("go env GOMOD: %w", err) + } + gomod := strings.TrimSpace(string(out)) + if gomod == "" || gomod == os.DevNull { + return "", fmt.Errorf("not in a Go module") + } + return filepath.Dir(gomod), nil +} + +// findKernelPath finds the gokrazy kernel vmlinuz path from go.mod. +func findKernelPath(goMod string) (string, error) { + // Import the same logic as nat_test.go. + b, err := os.ReadFile(goMod) + if err != nil { + return "", err + } + + goModCacheB, err := exec.Command("go", "env", "GOMODCACHE").CombinedOutput() + if err != nil { + return "", err + } + goModCache := strings.TrimSpace(string(goModCacheB)) + + // Parse go.mod to find gokrazy-kernel version. + for _, line := range strings.Split(string(b), "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "github.com/tailscale/gokrazy-kernel") { + parts := strings.Fields(line) + if len(parts) >= 2 { + return filepath.Join(goModCache, parts[0]+"@"+parts[1], "vmlinuz"), nil + } + } + } + return "", fmt.Errorf("gokrazy-kernel not found in %s", goMod) +} diff --git a/tstest/natlab/vmtest/vmtest_test.go b/tstest/natlab/vmtest/vmtest_test.go new file mode 100644 index 0000000000000..cadf570d15c19 --- /dev/null +++ b/tstest/natlab/vmtest/vmtest_test.go @@ -0,0 +1,980 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest_test + +import ( + "bytes" + "fmt" + "net/netip" + "strings" + "testing" + "time" + + "tailscale.com/client/local" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/tstest/natlab/vmtest" + "tailscale.com/tstest/natlab/vnet" + "tailscale.com/types/key" + "tailscale.com/types/netmap" +) + +func TestMacOSAndLinuxCanPing(t *testing.T) { + env := vmtest.New(t) + + lan := env.AddNetwork("192.168.1.1/24") + + linux := env.AddNode("linux", lan, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet()) + macos := env.AddNode("macos", lan, + vmtest.OS(vmtest.MacOS), + vmtest.DontJoinTailnet()) + + env.Start() + + env.LANPing(linux, macos.LanIP(lan)) +} + +func TestTwoMacOSVMsCanPing(t *testing.T) { + env := vmtest.New(t) + + lan := env.AddNetwork("192.168.1.1/24") + + mac1 := env.AddNode("mac1", lan, + vmtest.OS(vmtest.MacOS), + vmtest.DontJoinTailnet()) + mac2 := env.AddNode("mac2", lan, + vmtest.OS(vmtest.MacOS), + vmtest.DontJoinTailnet()) + + env.Start() + + // Both macOS VMs have TTA. Ping from mac1 to mac2 and vice versa. + env.LANPing(mac1, mac2.LanIP(lan)) + env.LANPing(mac2, mac1.LanIP(lan)) +} + +func TestSubnetRouter(t *testing.T) { + testSubnetRouterForOS(t, vmtest.Ubuntu2404) +} + +func TestSubnetRouterFreeBSD(t *testing.T) { + testSubnetRouterForOS(t, vmtest.FreeBSD150) +} + +func testSubnetRouterForOS(t testing.TB, srOS vmtest.OSImage) { + t.Helper() + env := vmtest.New(t) + + clientNet := env.AddNetwork("2.1.1.1", "192.168.1.1/24", "2000:1::1/64", vnet.EasyNAT) + internalNet := env.AddNetwork("10.0.0.1/24", "2000:2::1/64") + + client := env.AddNode("client", clientNet, + vmtest.OS(vmtest.Gokrazy)) + sr := env.AddNode("subnet-router", clientNet, internalNet, + vmtest.OS(srOS), + vmtest.AdvertiseRoutes("10.0.0.0/24")) + backend := env.AddNode("backend", internalNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + approveStep := env.AddStep("Approve subnet routes") + httpStep := env.AddStep("HTTP GET through subnet router") + + env.Start() + + approveStep.Begin() + env.ApproveRoutes(sr, "10.0.0.0/24") + approveStep.End(nil) + + httpStep.Begin() + body := env.HTTPGet(client, fmt.Sprintf("http://%s:8080/", backend.LanIP(internalNet))) + if !strings.Contains(body, "Hello world I am backend") { + httpStep.End(fmt.Errorf("got %q", body)) + t.Fatalf("got %q", body) + } + httpStep.End(nil) +} + +func TestSiteToSite(t *testing.T) { + testSiteToSite(t, vmtest.Ubuntu2404) +} + +// testSiteToSite runs a site-to-site subnet routing test with +// --snat-subnet-routes=false, verifying that original source IPs are preserved +// across Tailscale subnet routes. +// +// Topology: +// +// Site A: backend-a (10.1.0.0/24) ← → sr-a (WAN + LAN-A) +// Site B: backend-b (10.2.0.0/24) ← → sr-b (WAN + LAN-B) +// +// Both subnet routers are on Tailscale with --snat-subnet-routes=false. +// The test sends HTTP from backend-a to backend-b through the subnet routers +// and verifies that backend-b sees backend-a's LAN IP (not the subnet router's). +func testSiteToSite(t *testing.T, srOS vmtest.OSImage) { + env := vmtest.New(t) + + // WAN networks for each site (each behind NAT). + wanA := env.AddNetwork("2.1.1.1", "192.168.1.1/24", vnet.EasyNAT) + wanB := env.AddNetwork("3.1.1.1", "192.168.2.1/24", vnet.EasyNAT) + + // Internal LAN for each site. + lanA := env.AddNetwork("10.1.0.1/24") + lanB := env.AddNetwork("10.2.0.1/24") + + // Subnet routers: each on its WAN + LAN, advertising the local LAN, + // with SNAT disabled to preserve source IPs. + srA := env.AddNode("sr-a", wanA, lanA, + vmtest.OS(srOS), + vmtest.AdvertiseRoutes("10.1.0.0/24"), + vmtest.SNATSubnetRoutes(false)) + srB := env.AddNode("sr-b", wanB, lanB, + vmtest.OS(srOS), + vmtest.AdvertiseRoutes("10.2.0.0/24"), + vmtest.SNATSubnetRoutes(false)) + + // Backend servers on each site's LAN (not on Tailscale). + // Use Ubuntu so we can SSH in to add static routes. + backendA := env.AddNode("backend-a", lanA, + vmtest.OS(vmtest.Ubuntu2404), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + backendB := env.AddNode("backend-b", lanB, + vmtest.OS(vmtest.Ubuntu2404), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + approveStep := env.AddStep("Approve subnet routes (sr-a, sr-b)") + staticRouteStep := env.AddStep("Add static routes on backends") + httpStep := env.AddStep("HTTP GET through site-to-site") + + env.Start() + + approveStep.Begin() + env.ApproveRoutes(srA, "10.1.0.0/24") + env.ApproveRoutes(srB, "10.2.0.0/24") + approveStep.End(nil) + + // Add static routes on the backends so that traffic to the remote site's + // subnet goes through the local subnet router. This mirrors how a real + // site-to-site deployment is configured. + srALanIP := srA.LanIP(lanA).String() + srBLanIP := srB.LanIP(lanB).String() + t.Logf("sr-a LAN IP: %s, sr-b LAN IP: %s", srALanIP, srBLanIP) + t.Logf("backend-a LAN IP: %s, backend-b LAN IP: %s", backendA.LanIP(lanA), backendB.LanIP(lanB)) + + staticRouteStep.Begin() + env.AddRoute(backendA, "10.2.0.0/24", srALanIP) + env.AddRoute(backendB, "10.1.0.0/24", srBLanIP) + staticRouteStep.End(nil) + + // Make an HTTP request from backend-a to backend-b through the subnet routers. + // TTA's /http-get falls back to direct dial on non-Tailscale nodes. + httpStep.Begin() + backendBIP := backendB.LanIP(lanB) + body := env.HTTPGet(backendA, fmt.Sprintf("http://%s:8080/", backendBIP)) + t.Logf("response: %s", body) + + if !strings.Contains(body, "Hello world I am backend-b") { + httpStep.End(fmt.Errorf("expected response from backend-b, got %q", body)) + t.Fatalf("expected response from backend-b, got %q", body) + } + + // Verify the source IP was preserved. With --snat-subnet-routes=false, + // backend-b should see backend-a's LAN IP as the source, not sr-b's LAN IP. + backendAIP := backendA.LanIP(lanA).String() + if !strings.Contains(body, "from "+backendAIP) { + httpStep.End(fmt.Errorf("source IP not preserved: expected %q in response, got %q", backendAIP, body)) + t.Fatalf("source IP not preserved: expected %q in response, got %q", backendAIP, body) + } + httpStep.End(nil) +} + +// TestInterNetworkTCP verifies that vnet routes raw TCP between simulated +// networks: a non-Tailscale VM on one NAT'd LAN can reach a webserver on a +// different network using a 1:1 NAT, and the webserver sees the client's +// network's WAN IP as the source (post-NAT). +func TestInterNetworkTCP(t *testing.T) { + env := vmtest.New(t) + + const ( + clientWAN = "1.0.0.1" + webWAN = "5.0.0.1" + ) + + clientNet := env.AddNetwork(clientWAN, "192.168.1.1/24", vnet.EasyNAT) + webNet := env.AddNetwork(webWAN, "192.168.5.1/24", vnet.One2OneNAT) + + client := env.AddNode("client", clientNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet()) + env.AddNode("webserver", webNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + httpStep := env.AddStep("HTTP GET across networks via NAT") + + env.Start() + + httpStep.Begin() + body := env.HTTPGet(client, fmt.Sprintf("http://%s:8080/", webWAN)) + t.Logf("response: %s", body) + if !strings.Contains(body, "Hello world I am webserver") { + httpStep.End(fmt.Errorf("unexpected response: %q", body)) + t.Fatalf("unexpected response: %q", body) + } + if !strings.Contains(body, "from "+clientWAN) { + httpStep.End(fmt.Errorf("expected source %q in response, got %q", clientWAN, body)) + t.Fatalf("expected source %q in response, got %q", clientWAN, body) + } + httpStep.End(nil) +} + +// TestSubnetRouterPublicIP verifies that toggling --accept-routes on the +// client switches between dialing a webserver directly and routing through a +// subnet router that advertises the webserver's public IP range. +// +// Topology: client, subnet router, and webserver each live behind their own +// NAT'd network with distinct WAN IPs; the subnet router advertises the +// webserver's network as a route. The webserver echoes the source IP it +// sees: +// - accept-routes=off: client dials webserver directly; source is client's WAN. +// - accept-routes=on: client tunnels to the subnet router, which forwards +// and SNATs; source is subnet router's WAN. +func TestSubnetRouterPublicIP(t *testing.T) { + env := vmtest.New(t) + + const ( + clientWAN = "1.0.0.1" + routerWAN = "2.0.0.1" + webWAN = "5.0.0.1" + webRoute = "5.0.0.0/24" + ) + + clientNet := env.AddNetwork(clientWAN, "192.168.1.1/24", vnet.EasyNAT) + routerNet := env.AddNetwork(routerWAN, "192.168.2.1/24", vnet.EasyNAT) + webNet := env.AddNetwork(webWAN, "192.168.5.1/24", vnet.One2OneNAT) + + client := env.AddNode("client", clientNet, + vmtest.OS(vmtest.Gokrazy)) + sr := env.AddNode("subnet-router", routerNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.AdvertiseRoutes(webRoute)) + env.AddNode("webserver", webNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + approveStep := env.AddStep("Approve subnet route (public IP)") + checkOn1Step := env.AddStep("HTTP GET (accept-routes=on)") + checkOffStep := env.AddStep("HTTP GET (accept-routes=off)") + checkOn2Step := env.AddStep("HTTP GET (accept-routes=on, again)") + + env.Start() + // ApproveRoutes also turns on RouteAll on the client. + approveStep.Begin() + env.ApproveRoutes(sr, webRoute) + approveStep.End(nil) + + webURL := fmt.Sprintf("http://%s:8080/", webWAN) + check := func(step *vmtest.Step, label, wantSrc string) { + t.Helper() + step.Begin() + body := env.HTTPGet(client, webURL) + t.Logf("[%s] response: %s", label, body) + if !strings.Contains(body, "Hello world I am webserver") { + step.End(fmt.Errorf("[%s] unexpected webserver response: %q", label, body)) + t.Fatalf("[%s] unexpected webserver response: %q", label, body) + } + if !strings.Contains(body, "from "+wantSrc) { + step.End(fmt.Errorf("[%s] expected source %q in response, got %q", label, wantSrc, body)) + t.Fatalf("[%s] expected source %q in response, got %q", label, wantSrc, body) + } + step.End(nil) + } + + // accept-routes=on (set by ApproveRoutes): traffic flows via the subnet router. + check(checkOn1Step, "accept-routes=on", routerWAN) + + // accept-routes=off: client dials the webserver directly. + env.SetAcceptRoutes(client, false) + check(checkOffStep, "accept-routes=off", clientWAN) + + // Toggle back on to confirm the transition works in both directions. + env.SetAcceptRoutes(client, true) + check(checkOn2Step, "accept-routes=on (again)", routerWAN) +} + +// TestSubnetRouterAndExitNode checks how the subnet router and exit node +// preferences interact. Topology: client, subnet router, exit node, and +// webserver, each on its own NAT'd network with distinct WAN IPs. The subnet +// router advertises the webserver's network (5.0.0.0/24); the exit node +// advertises 0.0.0.0/0 + ::/0. The webserver echoes the source IP it sees: +// +// exit=off, subnet=off → client's WAN (direct dial) +// exit=off, subnet=on → subnet router's WAN +// exit=on, subnet=off → exit node's WAN +// exit=on, subnet=on → subnet router's WAN (more-specific /24 beats /0) +func TestSubnetRouterAndExitNode(t *testing.T) { + env := vmtest.New(t) + + const ( + clientWAN = "1.0.0.1" + routerWAN = "2.0.0.1" + exitWAN = "3.0.0.1" + webWAN = "5.0.0.1" + webRoute = "5.0.0.0/24" + ) + + clientNet := env.AddNetwork(clientWAN, "192.168.1.1/24", vnet.EasyNAT) + routerNet := env.AddNetwork(routerWAN, "192.168.2.1/24", vnet.EasyNAT) + exitNet := env.AddNetwork(exitWAN, "192.168.3.1/24", vnet.EasyNAT) + webNet := env.AddNetwork(webWAN, "192.168.5.1/24", vnet.One2OneNAT) + + client := env.AddNode("client", clientNet, + vmtest.OS(vmtest.Gokrazy)) + sr := env.AddNode("subnet-router", routerNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.AdvertiseRoutes(webRoute)) + exit := env.AddNode("exit", exitNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.AdvertiseRoutes("0.0.0.0/0,::/0")) + env.AddNode("webserver", webNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + approveStep := env.AddStep("Approve subnet & exit routes") + + webURL := fmt.Sprintf("http://%s:8080/", webWAN) + tests := []struct { + name string // subtest name; describes (exit, subnet) toggles + exit *vmtest.Node + subnet bool + wantSrc string + step *vmtest.Step + }{ + {"exit-off,subnet-off", nil, false, clientWAN, nil}, + {"exit-off,subnet-on", nil, true, routerWAN, nil}, + {"exit-on,subnet-off", exit, false, exitWAN, nil}, + // More-specific 5.0.0.0/24 from sr beats 0.0.0.0/0 from exit. + {"exit-on,subnet-on", exit, true, routerWAN, nil}, + } + for i := range tests { + tests[i].step = env.AddStep("HTTP GET: " + tests[i].name) + } + + env.Start() + approveStep.Begin() + env.ApproveRoutes(sr, webRoute) + env.ApproveRoutes(exit, "0.0.0.0/0", "::/0") + // Don't let the exit node itself forward via the subnet router: when the + // client is using the exit node only, we want the exit node to egress to + // the simulated internet directly so the webserver sees the exit's WAN. + env.SetAcceptRoutes(exit, false) + approveStep.End(nil) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc.step.Begin() + env.SetExitNode(client, tc.exit) + env.SetAcceptRoutes(client, tc.subnet) + body := env.HTTPGet(client, webURL) + t.Logf("response: %s", body) + if !strings.Contains(body, "Hello world I am webserver") { + tc.step.End(fmt.Errorf("unexpected webserver response: %q", body)) + t.Fatalf("unexpected webserver response: %q", body) + } + if !strings.Contains(body, "from "+tc.wantSrc) { + tc.step.End(fmt.Errorf("expected source %q in response, got %q", tc.wantSrc, body)) + t.Fatalf("expected source %q in response, got %q", tc.wantSrc, body) + } + tc.step.End(nil) + }) + } +} + +// TestTaildrop verifies that one Ubuntu node can send a file to another +// Ubuntu node via Taildrop, and the receiver gets the same content. +// +// Topology: two Ubuntu nodes, each behind its own EasyNAT, both joined to the +// tailnet. The sender runs `tailscale file cp` to push to the receiver's +// Tailscale IP; the receiver then runs `tailscale file get --wait` to fetch +// it. +func TestTaildrop(t *testing.T) { + env := vmtest.New(t, vmtest.SameTailnetUser()) + + senderNet := env.AddNetwork("1.0.0.1", "192.168.1.1/24", vnet.EasyNAT) + receiverNet := env.AddNetwork("2.0.0.1", "192.168.2.1/24", vnet.EasyNAT) + + sender := env.AddNode("sender", senderNet, + vmtest.OS(vmtest.Ubuntu2404)) + receiver := env.AddNode("receiver", receiverNet, + vmtest.OS(vmtest.Ubuntu2404)) + + // Declare test-specific steps for the web UI. + sendStep := env.AddStep("Taildrop send (sender -> receiver)") + recvStep := env.AddStep("Taildrop receive (on receiver)") + verifyStep := env.AddStep("Verify received name and contents") + + env.Start() + + const filename = "hello.txt" + want := []byte("hello world this is a Taildrop test\n") + + sendStep.Begin() + env.SendTaildropFile(sender, receiver, filename, want) + sendStep.End(nil) + + recvStep.Begin() + gotName, gotContent := env.RecvTaildropFile(t.Context(), receiver) + recvStep.End(nil) + + verifyStep.Begin() + if gotName != filename { + err := fmt.Errorf("received name = %q; want %q", gotName, filename) + verifyStep.End(err) + t.Error(err) + return + } + if !bytes.Equal(gotContent, want) { + err := fmt.Errorf("received content = %q; want %q", gotContent, want) + verifyStep.End(err) + t.Error(err) + return + } + verifyStep.End(nil) +} + +// TestExitNode verifies that switching the client's exit node setting between +// off, exit1, and exit2 correctly routes the client's internet traffic. +// +// Topology: each of the client and the two exit nodes lives behind its own NAT +// with a unique WAN IP, and a webserver lives on yet another network using a +// 1:1 NAT so it's reachable from the simulated internet at a stable address. +// The webserver echoes the source IP of incoming requests, so we can tell +// which network's NAT the client's traffic egressed through: +// - off: source is the client's network WAN IP. +// - exit1: source is exit1's network WAN IP. +// - exit2: source is exit2's network WAN IP. +func TestExitNode(t *testing.T) { + env := vmtest.New(t) + + const ( + clientWAN = "1.0.0.1" + exit1WAN = "2.0.0.1" + exit2WAN = "3.0.0.1" + webWAN = "5.0.0.1" + ) + + clientNet := env.AddNetwork(clientWAN, "192.168.1.1/24", vnet.EasyNAT) + exit1Net := env.AddNetwork(exit1WAN, "192.168.2.1/24", vnet.EasyNAT) + exit2Net := env.AddNetwork(exit2WAN, "192.168.3.1/24", vnet.EasyNAT) + webNet := env.AddNetwork(webWAN, "192.168.5.1/24", vnet.One2OneNAT) + + client := env.AddNode("client", clientNet, + vmtest.OS(vmtest.Gokrazy)) + exit1 := env.AddNode("exit1", exit1Net, + vmtest.OS(vmtest.Gokrazy), + vmtest.AdvertiseRoutes("0.0.0.0/0,::/0")) + exit2 := env.AddNode("exit2", exit2Net, + vmtest.OS(vmtest.Gokrazy), + vmtest.AdvertiseRoutes("0.0.0.0/0,::/0")) + env.AddNode("webserver", webNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + approveStep := env.AddStep("Approve exit-node routes (exit1, exit2)") + + webURL := fmt.Sprintf("http://%s:8080/", webWAN) + tests := []struct { + name string // subtest name + exit *vmtest.Node + wantSrc string + step *vmtest.Step + }{ + {"off", nil, clientWAN, nil}, + {"exit1", exit1, exit1WAN, nil}, + {"exit2", exit2, exit2WAN, nil}, + } + for i := range tests { + tests[i].step = env.AddStep("HTTP GET: exit=" + tests[i].name) + } + + env.Start() + approveStep.Begin() + env.ApproveRoutes(exit1, "0.0.0.0/0", "::/0") + env.ApproveRoutes(exit2, "0.0.0.0/0", "::/0") + approveStep.End(nil) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.step.Begin() + env.SetExitNode(client, tt.exit) + body := env.HTTPGet(client, webURL) + t.Logf("response: %s", body) + if !strings.Contains(body, "Hello world I am webserver") { + tt.step.End(fmt.Errorf("unexpected webserver response: %q", body)) + t.Fatalf("unexpected webserver response: %q", body) + } + if !strings.Contains(body, "from "+tt.wantSrc) { + tt.step.End(fmt.Errorf("expected source %q in response, got %q", tt.wantSrc, body)) + t.Fatalf("expected source %q in response, got %q", tt.wantSrc, body) + } + tt.step.End(nil) + }) + } +} + +// TestDiscoKeyChange verifies that when one node's disco key rotates without +// its WireGuard node key changing, peers detect the change, tear down stale +// WireGuard session state for that peer, and re-establish the tunnel in both +// directions. This exercises the disco-key-change handling that the +// bradfitz/rm_lazy_wg branch relies on for traffic to and from a peer whose +// magicsock state has been reset. +// +// Topology: two gokrazy nodes A and B, each on its own One2OneNAT network so +// every connection between them is a direct UDP path with no port-mapping or +// filtering. With NAT effects out of the way, what we measure here is the +// speed of disco-key-change reconciliation in wgengine/magicsock alone. The +// test control server is also configured with [testcontrol.Server.AllOnline] +// (via [vmtest.AllOnline]) so the controlclient/wgengine fast paths that +// branch on Online actually fire — without that flag the test exercises +// only the offline-peer code paths, which mask separate latent issues and +// are several seconds slower. +// +// The test runs four B-side rotations followed by a TSMP ping in the +// requested direction: +// +// rotate (LocalAPI rotate-disco-key) → ping B → A +// rotate (LocalAPI rotate-disco-key) → ping A → B +// restart (SIGKILL tailscaled) → ping B → A +// restart (SIGKILL tailscaled) → ping A → B +// +// Plus an initial A→B TSMP ping with a generous 30s budget to bring up the +// WireGuard tunnel before the rotations begin (so the post-rotation pings +// measure stale-state recovery, not first-time setup). All pings are TSMP +// because TSMP traverses the actual WireGuard data plane; PingDisco only +// exercises the magicsock disco layer and would mask any stale WG session +// problems. +// +// Two rotation methods are exercised: +// +// - LocalAPI rotate-disco-key (debug action): rolls B's magicsock disco +// private key in place, then bounces WantRunning to force wgengine to +// drop wireguard-go session keys for every peer (RotateDiscoKey alone +// only touches local disco state; without the WantRunning bounce, B +// keeps using stale per-peer session keys against A and A drops +// everything until B's WG rekey timer eventually fires). +// - SIGKILL of tailscaled (via TTA's /kill-tailscaled): the gokrazy +// supervisor respawns tailscaled, fully resetting B's magicsock and +// wgengine state in addition to rotating the disco key. +// +// Each post-rotation ping currently gets a 15-second budget. On a +// hypothetical perfect build it should take well under a second. In +// practice today there are two unavoidable multi-second waits: +// +// - The rotate-then-a→b phase on main takes ~10s for LazyWG. After +// B's WantRunning bounce, B's wgengine resets its sentActivityAt/ +// recvActivityAt maps and trims A out of the wireguard-go config +// as an "idle peer"; B only re-adds A on inbound activity, by +// which point A's first few TSMP packets have been silently +// dropped at B's tundev. The bradfitz/rm_lazy_wg branch removes +// that trimming entirely (verified locally), so this phase will +// drop to <100ms once that branch lands. +// +// - The restart phases take ~5s for the wireguard-go handshake retry +// timer. After SIGKILL+respawn the first WG handshake init from +// the restarted node sometimes goes into the void (likely the +// brief peer-removed window in the receiver's two-step +// [wgengine.userspaceEngine.maybeReconfigWireguardLocked] reconfig +// during which the peer is absent from wireguard-go), and wg-go's +// [device.RekeyTimeout] of 5s + jitter is the next opportunity to +// retry. That retry succeeds and the staged TSMP packet flushes. +// This is intrinsic to the protocol's retransmit policy. +// +// Once LazyWG is removed and the first-handshake-after-reconfig race +// is fixed, this budget should be tightened to 5s (or less). +// +// All four rotations also assert that B's WireGuard node key is unchanged. +func TestDiscoKeyChange(t *testing.T) { + // AllOnline makes the test control server mark every peer as Online=true + // in its MapResponses. Several disco-key handling fast paths + // (controlclient.removeUnwantedDiscoUpdates, + // removeUnwantedDiscoUpdatesFromFullNetmapUpdate, and the wgengine + // tsmpLearnedDisco fast path) only fire for online peers. Production + // control servers always populate Online; without this flag the test + // would only exercise the offline-peer paths. + env := vmtest.New(t, vmtest.AllOnline()) + + // One2OneNAT so each node has a 1:1 mapping to a public WAN IP with no + // port-translation or address-port filtering. This makes A↔B traffic + // behave like two unfirewalled hosts on the public internet, so any + // slowness we observe in this test cannot be blamed on NAT traversal. + aNet := env.AddNetwork("1.0.0.1", "192.168.1.1/24", vnet.One2OneNAT) + bNet := env.AddNetwork("2.0.0.1", "192.168.2.1/24", vnet.One2OneNAT) + + a := env.AddNode("a", aNet, vmtest.OS(vmtest.Gokrazy)) + b := env.AddNode("b", bNet, vmtest.OS(vmtest.Gokrazy)) + + type phase struct { + name string + rotate func() + pingFrom *vmtest.Node + pingTo *vmtest.Node + applyStep *vmtest.Step + verify *vmtest.Step + wait *vmtest.Step + ping *vmtest.Step + } + phases := []*phase{ + {name: "rotate (LocalAPI), b → a", pingFrom: b, pingTo: a, rotate: func() { env.RotateDiscoKey(b) }}, + {name: "rotate (LocalAPI), a → b", pingFrom: a, pingTo: b, rotate: func() { env.RotateDiscoKey(b) }}, + {name: "restart, b → a", pingFrom: b, pingTo: a, rotate: func() { env.RestartTailscaled(b) }}, + {name: "restart, a → b", pingFrom: a, pingTo: b, rotate: func() { env.RestartTailscaled(b) }}, + } + + pingABStep := env.AddStep("Ping a → b TSMP (establish tunnel)") + for _, p := range phases { + p.applyStep = env.AddStep("Apply: " + p.name) + p.verify = env.AddStep("Verify b: same node key, new disco key (" + p.name + ")") + p.wait = env.AddStep("Wait for a to see b's new disco key (" + p.name + ")") + p.ping = env.AddStep("Ping " + p.pingFrom.Name() + " → " + p.pingTo.Name() + " TSMP (" + p.name + ")") + } + + env.Start() + + pingABStep.Begin() + if err := env.Ping(a, b, tailcfg.PingTSMP, 30*time.Second); err != nil { + pingABStep.End(err) + t.Fatal(err) + } + pingABStep.End(nil) + + bStInitial := env.Status(b) + bNodeKey := bStInitial.Self.PublicKey + cs := env.ControlServer() + bCtlNode := cs.Node(bNodeKey) + if bCtlNode == nil { + t.Fatalf("control server has no node for b's key %v", bNodeKey) + } + prevDisco := bCtlNode.DiscoKey + if prevDisco.IsZero() { + t.Fatalf("control server has no disco key for b before rotation") + } + t.Logf("[b] initial: nodekey=%s discokey=%s", bNodeKey.ShortString(), prevDisco.ShortString()) + + for _, p := range phases { + p.applyStep.Begin() + p.rotate() + p.applyStep.End(nil) + prevDisco = checkDiscoRotated(t, env, a, b, p.pingFrom, p.pingTo, bNodeKey, prevDisco, p.name, + p.verify, p.wait, p.ping) + } +} + +// checkDiscoRotated verifies that after some action that should have rotated +// b's disco key, control has learned the new key, b's node key is unchanged, +// a's local view picks up the new disco key, and pingFrom can ping pingTo +// (TSMP) within the budget. It returns b's new disco key and fatals on +// failure. +// +// The TSMP ping budget is 15 seconds rather than the few hundred ms it +// ought to take. See the top-level test docstring for a full breakdown: +// it has to absorb LazyWG's trim+re-add for the rotate-a→b phase (~10s) +// and wireguard-go's RekeyTimeout retry for the SIGKILL+restart phases +// (~5s). Tighten this once both are addressed. +func checkDiscoRotated(t *testing.T, env *vmtest.Env, a, b, pingFrom, pingTo *vmtest.Node, bNodeKey key.NodePublic, oldDisco key.DiscoPublic, label string, verifyStep, waitStep, pingStep *vmtest.Step) key.DiscoPublic { + t.Helper() + cs := env.ControlServer() + + verifyStep.Begin() + bSt := env.Status(b) + if got := bSt.Self.PublicKey; got != bNodeKey { + err := fmt.Errorf("[%s] b's node key changed: %v -> %v", label, bNodeKey, got) + verifyStep.End(err) + t.Fatal(err) + } + var newDisco key.DiscoPublic + if err := tstest.WaitFor(15*time.Second, func() error { + n := cs.Node(bNodeKey) + if n == nil { + return fmt.Errorf("control server has no node for b") + } + if n.DiscoKey.IsZero() || n.DiscoKey == oldDisco { + return fmt.Errorf("control still has old disco key %v for b", n.DiscoKey) + } + newDisco = n.DiscoKey + return nil + }); err != nil { + verifyStep.End(err) + t.Fatalf("[%s] %v", label, err) + } + t.Logf("[b] after %s: nodekey=%s discokey=%s", label, bNodeKey.ShortString(), newDisco.ShortString()) + verifyStep.End(nil) + + waitStep.Begin() + if err := tstest.WaitFor(30*time.Second, func() error { + d, ok, err := env.PeerDiscoKey(a, bNodeKey) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("a doesn't yet have b in its status") + } + if d != newDisco { + return fmt.Errorf("a still sees b's old disco %v, want %v", d.ShortString(), newDisco.ShortString()) + } + return nil + }); err != nil { + waitStep.End(err) + env.DumpStatus(a) + t.Fatalf("[%s] %v", label, err) + } + waitStep.End(nil) + + pingStep.Begin() + t0 := time.Now() + if err := env.Ping(pingFrom, pingTo, tailcfg.PingTSMP, 15*time.Second); err != nil { + pingStep.End(err) + env.DumpStatus(a) + env.DumpStatus(b) + t.Fatalf("[%s] %v", label, err) + } + t.Logf("[%s] ping %s -> %s succeeded in %v", label, pingFrom.Name(), pingTo.Name(), time.Since(t0).Round(100*time.Millisecond)) + pingStep.End(nil) + return newDisco +} + +// TestMullvadExitNode verifies that a Tailscale client whose netmap contains +// a plain-WireGuard exit node (the way Mullvad exit nodes are wired up by +// the control plane) can route internet traffic through it, with the source +// IP rewritten to the per-client Mullvad-assigned address. +// +// Topology: +// +// client (Tailscale, gokrazy) — clientNet (EasyNAT) WAN 1.0.0.1 +// mullvad (Ubuntu, userspace WG) — mullvadNet (One2OneNAT) WAN 2.0.0.1 +// webserver (no Tailscale, gokrazy) — webNet (One2OneNAT) WAN 5.0.0.1 +// +// The mullvad VM impersonates a Mullvad WireGuard server. After boot, the +// test asks its TTA agent to bring up a userspace WireGuard interface (a +// real Linux TUN driven by wireguard-go) that pins the client's Tailscale +// node public key as its only allowed peer, sets up IP-forwarding + a +// MASQUERADE rule, and reports the WG server's freshly generated public +// key back. Userspace vs kernel WireGuard makes no difference on the wire +// — what's being tested is Tailscale's plain-WireGuard exit-node code +// path, not the kernel module. +// +// The test then injects a netmap peer with IsWireGuardOnly=true, +// AllowedIPs=[gw/32, 0.0.0.0/0, ::/0], the WG endpoint, and a per-client +// SelfNodeV4MasqAddrForThisPeer (the mock equivalent of the per-client IP +// Mullvad's API hands out at registration time). +// +// The webserver echoes the source IP it sees: +// - exit-node off: source is client's WAN (direct egress) +// - exit-node on: source is mullvad's WAN (egress via WG + MASQUERADE) +func TestMullvadExitNode(t *testing.T) { + env := vmtest.New(t) + + const ( + clientWAN = "1.0.0.1" + mullvadWAN = "2.0.0.1" + webWAN = "5.0.0.1" + ) + // Mullvad-side WG network. The client appears as clientMasqIP to + // mullvad's wg0; mullvad terminates the tunnel at gw. + var ( + mullvadWGNet = netip.MustParsePrefix("10.64.0.0/24") + gw = netip.MustParsePrefix("10.64.0.1/24") + clientMasq = netip.MustParsePrefix("10.64.0.2/32") + ) + const wgListenPort uint16 = 51820 + + clientNet := env.AddNetwork(clientWAN, "192.168.1.1/24", vnet.EasyNAT) + mullvadNet := env.AddNetwork(mullvadWAN, "192.168.2.1/24", vnet.One2OneNAT) + webNet := env.AddNetwork(webWAN, "192.168.5.1/24", vnet.One2OneNAT) + + client := env.AddNode("client", clientNet, vmtest.OS(vmtest.Gokrazy)) + mullvad := env.AddNode("mullvad", mullvadNet, + vmtest.OS(vmtest.Ubuntu2404), + vmtest.DontJoinTailnet()) + env.AddNode("webserver", webNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + wgUpStep := env.AddStep("Bring up Mullvad WG server") + injectStep := env.AddStep("Inject Mullvad netmap peer") + checkOff1Step := env.AddStep("HTTP GET (exit off)") + checkMullvadStep := env.AddStep("HTTP GET (exit=mullvad)") + checkOff2Step := env.AddStep("HTTP GET (exit off, again)") + + env.Start() + + // Bring up the WG server inside mullvad's TTA, pinning the client's + // Tailscale node public key as the sole allowed peer. + wgUpStep.Begin() + clientStatus := env.Status(client) + mullvadPub := env.BringUpMullvadWGServer(mullvad, + gw, wgListenPort, + clientStatus.Self.PublicKey, clientMasq, mullvadWGNet) + wgUpStep.End(nil) + + // Inject the mullvad node into the netmap as a plain-WireGuard exit + // node. This mirrors how the control plane describes Mullvad exit + // nodes to clients (see control/cmullvad in the closed repo): a + // peer with IsWireGuardOnly=true, an Endpoints entry pointing at + // the public WG host:port, and AllowedIPs covering both the gateway + // /32 and the 0.0.0.0/0+::/0 exit-node routes. + injectStep.Begin() + mullvadEndpoint := netip.AddrPortFrom(netip.MustParseAddr(mullvadWAN), wgListenPort) + gwHost := netip.PrefixFrom(gw.Addr(), gw.Addr().BitLen()) + mullvadNode := &tailcfg.Node{ + ID: 999_001, + StableID: "mullvad-test", + Name: "mullvad-test.fake-control.example.net.", + Key: mullvadPub, + MachineAuthorized: true, + IsWireGuardOnly: true, + Endpoints: []netip.AddrPort{mullvadEndpoint}, + Addresses: []netip.Prefix{gwHost}, + AllowedIPs: []netip.Prefix{ + gwHost, + netip.MustParsePrefix("0.0.0.0/0"), + netip.MustParsePrefix("::/0"), + }, + Hostinfo: (&tailcfg.Hostinfo{ + Hostname: "mullvad-test", + }).View(), + } + cs := env.ControlServer() + cs.UpdateNode(mullvadNode) + + // Set the per-peer source-IP masquerade. The control plane normally + // derives this from the Mullvad API's per-client registration; here + // we just pin it to the address mullvad's wg0 was told to accept. + cs.SetMasqueradeAddresses([]testcontrol.MasqueradePair{{ + Node: clientStatus.Self.PublicKey, + Peer: mullvadPub, + NodeMasqueradesAs: clientMasq.Addr(), + }}) + injectStep.End(nil) + + webURL := fmt.Sprintf("http://%s:8080/", webWAN) + check := func(step *vmtest.Step, label, wantSrc string) { + t.Helper() + step.Begin() + body := env.HTTPGet(client, webURL) + t.Logf("[%s] response: %s", label, body) + if !strings.Contains(body, "Hello world I am webserver") { + step.End(fmt.Errorf("[%s] unexpected webserver response: %q", label, body)) + t.Fatalf("[%s] unexpected webserver response: %q", label, body) + } + if !strings.Contains(body, "from "+wantSrc) { + step.End(fmt.Errorf("[%s] expected source %q in response, got %q", label, wantSrc, body)) + t.Fatalf("[%s] expected source %q in response, got %q", label, wantSrc, body) + } + step.End(nil) + } + + // Exit-node off: client routes 0.0.0.0/0 directly via its host stack, + // so the webserver sees client's WAN IP. + check(checkOff1Step, "exit-off", clientWAN) + + // Switch to the Mullvad WG-only peer as exit node. The client should + // now route 0.0.0.0/0 through the WG tunnel; mullvad MASQUERADEs to + // its WAN; the webserver sees the mullvad VM's WAN IP. + env.SetExitNodeIP(client, gw.Addr()) + check(checkMullvadStep, "exit-mullvad", mullvadWAN) + + // And back off again, to make sure the transition works in both + // directions. + env.SetExitNodeIP(client, netip.Addr{}) + check(checkOff2Step, "exit-off (again)", clientWAN) +} + +// TestCachedNetmapAfterRestart verifies that two nodes with netmap +// caching enabled (NodeAttrCacheNetworkMaps) can re-establish a direct +// WireGuard tunnel after both are restarted while the control server is +// unreachable. After restart the nodes must use only their on-disk cached +// netmaps to re-connect. +func TestCachedNetmapAfterRestart(t *testing.T) { + env := vmtest.New(t) + + aNet := env.AddNetwork("1.0.0.1", "192.168.1.1/24", vnet.EasyNAT) + bNet := env.AddNetwork("2.0.0.1", "192.168.2.1/24", vnet.EasyNAT) + + aNet.SetPostConnectControlBlackhole(true) + bNet.SetPostConnectControlBlackhole(true) + + a := env.AddNode("a", aNet, + vmtest.OS(vmtest.Gokrazy), + tailcfg.NodeCapMap{tailcfg.NodeAttrCacheNetworkMaps: nil}) + b := env.AddNode("b", bNet, + vmtest.OS(vmtest.Gokrazy), + tailcfg.NodeCapMap{tailcfg.NodeAttrCacheNetworkMaps: nil}) + + connectStep := env.AddStep("Establish initial TSMP tunnel") + cutControlStep := env.AddStep("Cut control server access") + restartStep := env.AddStep("Restart tailscaled on both nodes") + netmapCheckStep := env.AddStep("Check netmap loaded is cached") + pingStep := env.AddStep("Ping a → b TSMP (cached netmap, no control)") + + env.Start() + + connectStep.Begin() + if err := env.Ping(a, b, tailcfg.PingTSMP, 30*time.Second); err != nil { + connectStep.End(err) + t.Fatal(err) + } + connectStep.End(nil) + + cutControlStep.Begin() + aNet.PostConnectedToControl() + bNet.PostConnectedToControl() + env.ControlServer().SetOnMapRequest(func(nk key.NodePublic) { + panic(fmt.Sprintf("got connection from %v", nk)) + }) + cutControlStep.End(nil) + + restartStep.Begin() + env.RestartTailscaled(a) + env.RestartTailscaled(b) + restartStep.End(nil) + + netmapCheckStep.Begin() + for _, node := range []*vmtest.Node{a, b} { + nm, err := local.GetDebugResultJSON[netmap.NetworkMap](t.Context(), node.Agent().Client, "current-netmap") + if err != nil { + netmapCheckStep.End(fmt.Errorf("[%s] got err fetching netmap %q", node.Name(), err)) + t.Fatalf("[%s] got err fetching netmap %q", node.Name(), err) + } + if !nm.Cached { + netmapCheckStep.End(fmt.Errorf("[%s] expected netmap.Cached = true, got: %t", node.Name(), nm.Cached)) + t.Fatalf("[%s] expected netmap.Cached = true, got: %t", node.Name(), nm.Cached) + } + } + netmapCheckStep.End(nil) + + pingStep.Begin() + if err := env.Ping(a, b, tailcfg.PingTSMP, 30*time.Second); err != nil { + pingStep.End(err) + t.Fatal(err) + } + pingStep.End(nil) +} diff --git a/tstest/natlab/vmtest/web.go b/tstest/natlab/vmtest/web.go new file mode 100644 index 0000000000000..d512740e6c83d --- /dev/null +++ b/tstest/natlab/vmtest/web.go @@ -0,0 +1,209 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "embed" + "flag" + "fmt" + "hash/crc32" + "html/template" + "io" + "io/fs" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/coder/websocket" + "github.com/robert-nix/ansihtml" +) + +var vmtestWeb = flag.String("vmtest-web", "", "listen address for vmtest web UI (e.g. :0, localhost:0, :8080)") + +//go:embed assets/*.html +var templatesSrc embed.FS + +//go:embed assets/*.css +var staticAssets embed.FS + +var tmpl = sync.OnceValue(func() *template.Template { + d, err := fs.Sub(templatesSrc, "assets") + if err != nil { + panic(fmt.Errorf("getting vmtest web templates subdir: %w", err)) + } + return template.Must(template.New("").Funcs(template.FuncMap{ + "formatDuration": formatDuration, + "ansi": ansiToHTML, + }).ParseFS(d, "*")) +}) + +// ansiToHTML converts a string with ANSI escape sequences to HTML with +// inline styles. Returns template.HTML so html/template doesn't double-escape it. +func ansiToHTML(s string) template.HTML { + return template.HTML(ansihtml.ConvertToHTML([]byte(s))) +} + +// formatDuration returns a human-readable duration like "1.2s" or "45.3s". +func formatDuration(d time.Duration) string { + if d < time.Second { + return fmt.Sprintf("%dms", d.Milliseconds()) + } + return fmt.Sprintf("%.1fs", d.Seconds()) +} + +// deterministicPort returns a deterministic port in the range [20000, 40000) +// based on the test name, so re-running the same test gets the same URL. +func deterministicPort(testName string) int { + return int(crc32.ChecksumIEEE([]byte(testName)))%20000 + 20000 +} + +// listenWeb listens on the given address. If the port is 0, it first tries a +// deterministic port based on the test name so re-runs get the same URL. +// Falls back to :0 (OS-assigned) on any listen error. +func (e *Env) listenWeb(addr string) (net.Listener, error) { + host, port, _ := net.SplitHostPort(addr) + if port == "0" { + detPort := deterministicPort(e.t.Name()) + detAddr := net.JoinHostPort(host, fmt.Sprintf("%d", detPort)) + if ln, err := net.Listen("tcp", detAddr); err == nil { + return ln, nil + } + // Deterministic port busy; fall back to OS-assigned. + } + return net.Listen("tcp", addr) +} + +// maybeStartWebServer starts the web UI if --vmtest-web is set. +// Called at the very top of Env.Start(), before compilation or image downloads. +func (e *Env) maybeStartWebServer() { + addr := *vmtestWeb + if addr == "" { + return + } + + ln, err := e.listenWeb(addr) + if err != nil { + e.t.Fatalf("vmtest-web listen: %v", err) + } + e.t.Cleanup(func() { ln.Close() }) + + actualAddr := ln.Addr().(*net.TCPAddr) + + host, _, _ := net.SplitHostPort(addr) + if host == "" || host == "0.0.0.0" || host == "::" { + hostname, err := os.Hostname() + if err != nil { + hostname = "localhost" + } + e.t.Logf("Status at http://%s:%d/", hostname, actualAddr.Port) + } else { + e.t.Logf("Status at http://%s/", actualAddr.String()) + } + + mux := http.NewServeMux() + mux.HandleFunc("GET /", e.serveIndex) + mux.HandleFunc("GET /ws", e.serveWebSocket) + mux.HandleFunc("GET /screenshot/{node}", e.serveScreenshot) + mux.HandleFunc("GET /style.css", serveStaticAsset("style.css")) + + srv := &http.Server{Handler: mux} + go srv.Serve(ln) + e.t.Cleanup(func() { srv.Close() }) +} + +func serveStaticAsset(name string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if !strings.HasSuffix(name, ".css") { + http.Error(w, "not found", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "text/css") + f, err := staticAssets.Open(filepath.Join("assets", name)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer f.Close() + io.Copy(w, f) + } +} + +func (e *Env) serveIndex(w http.ResponseWriter, r *http.Request) { + type indexData struct { + TestName string + TestStatus *TestStatus + Steps []*Step + Nodes []NodeStatus + } + + data := indexData{ + TestName: e.t.Name(), + TestStatus: e.testStatus, + Steps: e.Steps(), + } + for _, n := range e.nodes { + data.Nodes = append(data.Nodes, e.getNodeStatus(n.name)) + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if err := tmpl().ExecuteTemplate(w, "index.html", data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// serveScreenshot proxies a full-resolution screenshot from the Host.app +// screenshot server. Returns raw JPEG with no HTML wrapper. +func (e *Env) serveScreenshot(w http.ResponseWriter, r *http.Request) { + name := r.PathValue("node") + port := e.nodeScreenshotPort(name) + if port == 0 { + http.Error(w, "no screenshot server for node", http.StatusNotFound) + return + } + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/screenshot?full=1", port)) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer resp.Body.Close() + w.Header().Set("Content-Type", "image/jpeg") + io.Copy(w, resp.Body) +} + +func (e *Env) serveWebSocket(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + return + } + defer conn.CloseNow() + wsCtx := conn.CloseRead(r.Context()) + + sub := e.eventBus.Subscribe() + defer sub.Close() + + for { + select { + case <-wsCtx.Done(): + return + case <-sub.Done(): + return + case ev := <-sub.Events(): + msg, err := conn.Writer(r.Context(), websocket.MessageText) + if err != nil { + return + } + if err := tmpl().ExecuteTemplate(msg, "event.html", ev); err != nil { + msg.Close() + return + } + if err := msg.Close(); err != nil { + return + } + } + } +} diff --git a/tstest/natlab/vnet/conf.go b/tstest/natlab/vnet/conf.go index 3f83e35c09ba3..7cfd0e38cb621 100644 --- a/tstest/natlab/vnet/conf.go +++ b/tstest/natlab/vnet/conf.go @@ -5,6 +5,7 @@ package vnet import ( "cmp" + "context" "fmt" "iter" "net/netip" @@ -14,6 +15,7 @@ import ( "github.com/google/gopacket/layers" "github.com/google/gopacket/pcapgo" + "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/util/must" "tailscale.com/util/set" @@ -71,6 +73,13 @@ func nodeMac(n int) MAC { return MAC{0x52, 0xcc, 0xcc, 0xcc, 0xcc, byte(n)} } +// nodeNICMac returns the MAC for the nicIdx-th secondary NIC (1-indexed) of node n. +// Primary NICs (index 0) use nodeMac. Secondary NICs use a different scheme: +// 52:cc:cc:cc:KK:NN where KK is the NIC index and NN is the node number. +func nodeNICMac(nodeNum, nicIdx int) MAC { + return MAC{0x52, 0xcc, 0xcc, 0xcc, byte(nicIdx), byte(nodeNum)} +} + func routerMac(n int) MAC { // 52=TS then 0xee for 'etwork return MAC{0x52, 0xee, 0xee, 0xee, 0xee, byte(n)} @@ -114,6 +123,12 @@ func (c *Config) AddNode(opts ...any) *Node { switch o { case HostFirewall: n.hostFW = true + case RotateDisco: + n.rotateDisco = true + case PreICMPPing: + n.preICMPPing = true + case DontJoinTailnet: + n.dontJoinTailnet = true case VerboseSyslog: n.verboseSyslog = true default: @@ -123,6 +138,8 @@ func (c *Config) AddNode(opts ...any) *Node { } case MAC: n.mac = o + case tailcfg.NodeCapMap: + n.capMap = o default: if n.err == nil { n.err = fmt.Errorf("unknown AddNode option type %T", o) @@ -136,8 +153,11 @@ func (c *Config) AddNode(opts ...any) *Node { type NodeOption string const ( - HostFirewall NodeOption = "HostFirewall" - VerboseSyslog NodeOption = "VerboseSyslog" + HostFirewall NodeOption = "HostFirewall" + RotateDisco NodeOption = "RotateDisco" + PreICMPPing NodeOption = "PreICMPPing" + DontJoinTailnet NodeOption = "DontJoinTailnet" + VerboseSyslog NodeOption = "VerboseSyslog" ) // TailscaledEnv is а option that can be passed to Config.AddNode @@ -197,13 +217,18 @@ func (c *Config) AddNetwork(opts ...any) *Network { // Node is the configuration of a node in the virtual network. type Node struct { - err error - num int // 1-based node number - n *node // nil until NewServer called - - env []TailscaledEnv - hostFW bool - verboseSyslog bool + err error + num int // 1-based node number + n *node // nil until NewServer called + client *NodeAgentClient + + env []TailscaledEnv + hostFW bool + rotateDisco bool + preICMPPing bool + verboseSyslog bool + dontJoinTailnet bool + capMap tailcfg.NodeCapMap // TODO(bradfitz): this is halfway converted to supporting multiple NICs // but not done. We need a MAC-per-Network. @@ -222,11 +247,31 @@ func (n *Node) String() string { return fmt.Sprintf("node%d", n.num) } -// MAC returns the MAC address of the node. +// MAC returns the MAC address of the node's primary NIC. func (n *Node) MAC() MAC { return n.mac } +// NumNICs returns the number of network interfaces on the node +// (one per network the node is on). +func (n *Node) NumNICs() int { + return len(n.nets) +} + +// NICMac returns the MAC address for the i-th NIC (0-indexed). +// NIC 0 is the primary NIC (same as MAC()). NIC 1+ are extra NICs. +func (n *Node) NICMac(i int) MAC { + if i == 0 { + return n.mac + } + return nodeNICMac(n.num, i) +} + +// Networks returns the list of networks this node is on. +func (n *Node) Networks() []*Network { + return n.nets +} + func (n *Node) Env() []TailscaledEnv { return n.env } @@ -243,6 +288,46 @@ func (n *Node) SetVerboseSyslog(v bool) { n.verboseSyslog = v } +func (n *Node) SetClient(c *NodeAgentClient) { + n.client = c +} + +// PostConnectedToControl should be called after the clients have connected to +// control to modify the client behaviour after getting the network maps. +// Currently, the only implemented behavior is rotating disco keys. +func (n *Node) PostConnectedToControl(ctx context.Context) error { + if n.rotateDisco { + if err := n.client.DebugAction(ctx, "rotate-disco-key"); err != nil { + return err + } + } + return nil +} + +// PreICMPPing reports whether node should send an ICMP Ping sent before +// the disco ping. This is important for the nodes having rotated their +// disco keys while control is down. Disco pings deliberately does not +// trigger a TSMPDiscoKeyAdvertisement, making the need for other traffic (here +// simlulated as an ICMP ping) needed first. Any traffic could trigger this key +// exchange, the ICMP Ping is used as a handy existing way of sending some +// non-disco traffic. +func (n *Node) PreICMPPing() bool { + return n.preICMPPing +} + +// ShouldJoinTailnet reports whether node should join the test tailnet. Machines in +// the virtual universe that aren't on the tailnet are useful for testing that +// Tailscale does not break connectivity to resources outside the tailnet. +func (n *Node) ShouldJoinTailnet() bool { + return !n.dontJoinTailnet +} + +// WantCapMap returns the [tailcfg.NodeCapMap] that control should send down to +// this node, if any. +func (n *Node) WantCapMap() tailcfg.NodeCapMap { + return n.capMap +} + // IsV6Only reports whether this node is only connected to IPv6 networks. func (n *Node) IsV6Only() bool { for _, net := range n.nets { @@ -258,6 +343,26 @@ func (n *Node) IsV6Only() bool { return false } +// LanIP returns the node's LAN IPv4 address on the given network. +// It requires the [Server] to have been initialized (i.e., [New] was called). +// Returns an invalid addr if the node has no IP on that network. +func (n *Node) LanIP(net *Network) netip.Addr { + if n.n == nil { + return netip.Addr{} + } + for i, nn := range n.nets { + if nn == net { + if i == 0 { + return n.n.lanIP + } + if i-1 < len(n.n.extraNICs) { + return n.n.extraNICs[i-1].lanIP + } + } + } + return netip.Addr{} +} + // Network returns the first network this node is connected to, // or nil if none. func (n *Node) Network() *Network { @@ -275,10 +380,12 @@ type Network struct { wanIP6 netip.Prefix // global unicast router in host bits; CIDR is /64 delegated to LAN - wanIP4 netip.Addr // IPv4 WAN IP, if any - lanIP4 netip.Prefix - nodes []*Node - breakWAN4 bool // whether to break WAN IPv4 connectivity + wanIP4 netip.Addr // IPv4 WAN IP, if any + lanIP4 netip.Prefix + nodes []*Node + breakWAN4 bool // whether to break WAN IPv4 connectivity + postConnectBlackholeControl bool // whether to break control connectivity after nodes have connected + network *network svcs set.Set[NetworkService] @@ -310,6 +417,12 @@ func (n *Network) SetBlackholedIPv4(v bool) { n.breakWAN4 = v } +// SetPostConnectControlBlackhole sets whether the network should blackhole all +// traffic to the control server after the clients have connected. +func (n *Network) SetPostConnectControlBlackhole(v bool) { + n.postConnectBlackholeControl = v +} + func (n *Network) CanV4() bool { return n.lanIP4.IsValid() || n.wanIP4.IsValid() } @@ -325,6 +438,13 @@ func (n *Network) CanTakeMoreNodes() bool { return len(n.nodes) < 150 } +// PostConnectedToControl should be called after the clients have connected to +// the control server to modify network behaviors. Currently the only +// implemented behavior is to conditionally blackhole traffic to control. +func (n *Network) PostConnectedToControl() { + n.network.SetControlBlackholed(n.postConnectBlackholeControl) +} + // NetworkService is a service that can be added to a network. type NetworkService string @@ -390,6 +510,8 @@ func (s *Server) initFromConfig(c *Config) error { } netOfConf[conf] = n s.networks.Add(n) + + conf.network = n if conf.wanIP4.IsValid() { if conf.wanIP4.Is6() { return fmt.Errorf("invalid IPv6 address in wanIP") @@ -421,10 +543,11 @@ func (s *Server) initFromConfig(c *Config) error { if conf.err != nil { return conf.err } + primaryNet := netOfConf[conf.Network()] n := &node{ num: conf.num, mac: conf.mac, - net: netOfConf[conf.Network()], + net: primaryNet, verboseSyslog: conf.VerboseSyslog(), } n.interfaceID = must.Get(s.pcapWriter.AddInterface(pcapgo.NgInterface{ @@ -438,16 +561,50 @@ func (s *Server) initFromConfig(c *Config) error { s.nodes = append(s.nodes, n) s.nodeByMAC[n.mac] = n - if n.net.v4 { + if n.net != nil && n.net.v4 { // Allocate a lanIP for the node. Use the network's CIDR and use final // octet 101 (for first node), 102, etc. The node number comes from the - // last octent of the MAC address (0-based) + // last octet of the MAC address (0-based) ip4 := n.net.lanIP4.Addr().As4() ip4[3] = 100 + n.mac[5] n.lanIP = netip.AddrFrom4(ip4) n.net.nodesByIP4[n.lanIP] = n } - n.net.nodesByMAC[n.mac] = n + if n.net != nil { + n.net.nodesByMAC[n.mac] = n + } + + // Set up extra NICs for multi-homed nodes (nodes on more than one network). + for nicIdx, confNet := range conf.nets[1:] { + extraNet := netOfConf[confNet] + if extraNet == nil { + continue + } + mac := nodeNICMac(conf.num, nicIdx+1) + nic := nodeNIC{ + mac: mac, + net: extraNet, + } + nic.interfaceID = must.Get(s.pcapWriter.AddInterface(pcapgo.NgInterface{ + Name: fmt.Sprintf("%s-nic%d", n.String(), nicIdx+1), + LinkType: layers.LinkTypeEthernet, + })) + // Allocate a lanIP for the node. Use the network's CIDR and use final + // octet 101 (for first node), 102, etc. The node number comes from the + // last octet of the MAC address (0-based) + if extraNet.v4 { + ip4 := extraNet.lanIP4.Addr().As4() + ip4[3] = 100 + mac[5] + nic.lanIP = netip.AddrFrom4(ip4) + extraNet.nodesByIP4[nic.lanIP] = n + } + extraNet.nodesByMAC[mac] = n + if _, ok := s.nodeByMAC[mac]; ok { + return fmt.Errorf("two nodes have the same MAC %v", mac) + } + s.nodeByMAC[mac] = n + n.extraNICs = append(n.extraNICs, nic) + } } // Now that nodes are populated, set up NAT: diff --git a/tstest/natlab/vnet/vip.go b/tstest/natlab/vnet/vip.go index 9d7aa56a3d2a0..07b64f54c0615 100644 --- a/tstest/natlab/vnet/vip.go +++ b/tstest/natlab/vnet/vip.go @@ -19,6 +19,8 @@ var ( fakeDERP2 = newVIP("derp2.tailscale", "33.4.0.2") // 3340=DERP; 2=derp 2 fakeLogCatcher = newVIP("log.tailscale.com", 4) fakeSyslog = newVIP("syslog.tailscale", 9) + fakeCloudInit = newVIP("cloud-init.tailscale", 5) // serves cloud-init metadata/userdata per node + fakeFiles = newVIP("files.tailscale", 6) // serves binary files (tta, tailscale, tailscaled) to VMs ) type virtualIP struct { @@ -31,6 +33,13 @@ func (v virtualIP) Match(a netip.Addr) bool { return v.v4 == a.Unmap() || v.v6 == a } +// TestDriverIPv4 returns the IPv4 address of the test driver VIP (52.52.0.2). +// TTA agents dial this IP on port TestDriverPort to connect to the test harness. +func TestDriverIPv4() netip.Addr { return fakeTestAgent.v4 } + +// TestDriverPort is the port the test driver listens on. +const TestDriverPort = 8008 + // FakeDNSIPv4 returns the fake DNS IPv4 address. func FakeDNSIPv4() netip.Addr { return fakeDNS.v4 } diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index 357fe213c8c28..1c28c2c5dd584 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -30,6 +30,7 @@ import ( "net/netip" "os/exec" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -267,10 +268,13 @@ func (n *network) handleIPPacketFromGvisor(ipRaw []byte) { n.logf("gvisor: serialize error: %v", err) return } - if nw, ok := n.writers.Load(node.mac); ok { + // Use the MAC address for this specific network (important for multi-NIC nodes + // where the primary MAC may be on a different network). + mac := node.macForNet(n) + if nw, ok := n.writers.Load(mac); ok { nw.write(resPkt) } else { - n.logf("gvisor write: no writeFunc for %v", node.mac) + n.logf("gvisor write: no writeFunc for %v (node %v on net %v)", mac, node, n.mac) } } @@ -290,6 +294,24 @@ func stringifyTEI(tei stack.TransportEndpointID) string { return fmt.Sprintf("%s -> %s", remoteHostPort, localHostPort) } +// vipNameOf returns the VIP name for the given IP, or "" if it's not a VIP. +func vipNameOf(ip netip.Addr) string { + for _, v := range vips { + if v.Match(ip) { + return v.name + } + } + return "" +} + +// nodeNameOf returns the node's name for the given IP on this network, or "" if unknown. +func (n *network) nodeNameOf(ip netip.Addr) string { + if node, ok := n.nodeByIP(ip); ok { + return node.String() + } + return "" +} + func (n *network) acceptTCP(r *tcp.ForwarderRequest) { reqDetails := r.ID() @@ -301,7 +323,17 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { return } - log.Printf("vnet-AcceptTCP: %v", stringifyTEI(reqDetails)) + // Annotate the log with node/VIP names for readability. + srcHP := net.JoinHostPort(clientRemoteIP.String(), strconv.Itoa(int(reqDetails.RemotePort))) + srcStr := srcHP + if name := n.nodeNameOf(clientRemoteIP); name != "" { + srcStr = fmt.Sprintf("%s (%s)", srcHP, name) + } + dstStr := net.JoinHostPort(destIP.String(), strconv.Itoa(int(destPort))) + if name := vipNameOf(destIP); name != "" { + dstStr = fmt.Sprintf("%s (%s)", dstStr, name) + } + log.Printf("vnet-AcceptTCP: %s -> %s", srcStr, dstStr) var wq waiter.Queue ep, err := r.CreateEndpoint(&wq) @@ -320,7 +352,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { return } - if destPort == 8008 && fakeTestAgent.Match(destIP) { + if destPort == TestDriverPort && fakeTestAgent.Match(destIP) { node, ok := n.nodeByIP(clientRemoteIP) if !ok { n.logf("unknown client IP %v trying to connect to test driver", clientRemoteIP) @@ -371,6 +403,22 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { return } + if destPort == 80 && fakeCloudInit.Match(destIP) { + r.Complete(false) + tc := gonet.NewTCPConn(&wq, ep) + hs := &http.Server{Handler: n.s.cloudInitHandler()} + go hs.Serve(netutil.NewOneConnListener(tc, nil)) + return + } + + if destPort == 80 && fakeFiles.Match(destIP) { + r.Complete(false) + tc := gonet.NewTCPConn(&wq, ep) + hs := &http.Server{Handler: n.s.fileServerHandler()} + go hs.Serve(netutil.NewOneConnListener(tc, nil)) + return + } + var targetDial string if n.s.derpIPs.Contains(destIP) { targetDial = destIP.String() + ":" + strconv.Itoa(int(destPort)) @@ -506,23 +554,24 @@ func (nw networkWriter) write(b []byte) { } type network struct { - s *Server - num int // 1-based - mac MAC // of router - portmap bool - lanInterfaceID int - wanInterfaceID int - v4 bool // network supports IPv4 - v6 bool // network support IPv6 - wanIP6 netip.Prefix // router's WAN IPv6, if any, as a /64. - wanIP4 netip.Addr // router's LAN IPv4, if any - lanIP4 netip.Prefix // router's LAN IP + CIDR (e.g. 192.168.2.1/24) - breakWAN4 bool // break WAN IPv4 connectivity - latency time.Duration // latency applied to interface writes - lossRate float64 // probability of dropping a packet (0.0 to 1.0) - nodesByIP4 map[netip.Addr]*node // by LAN IPv4 - nodesByMAC map[MAC]*node - logf func(format string, args ...any) + s *Server + num int // 1-based + mac MAC // of router + portmap bool + lanInterfaceID int + wanInterfaceID int + v4 bool // network supports IPv4 + v6 bool // network support IPv6 + wanIP6 netip.Prefix // router's WAN IPv6, if any, as a /64. + wanIP4 netip.Addr // router's LAN IPv4, if any + lanIP4 netip.Prefix // router's LAN IP + CIDR (e.g. 192.168.2.1/24) + breakWAN4 bool // break WAN IPv4 connectivity + blackholeControl bool // blackhole control connectivity + latency time.Duration // latency applied to interface writes + lossRate float64 // probability of dropping a packet (0.0 to 1.0) + nodesByIP4 map[netip.Addr]*node // by LAN IPv4 + nodesByMAC map[MAC]*node + logf func(format string, args ...any) ns *stack.Stack linkEP *channel.Endpoint @@ -572,12 +621,29 @@ func (n *network) MACOfIP(ip netip.Addr) (_ MAC, ok bool) { if n.lanIP4.Addr() == ip { return n.mac, true } - if n, ok := n.nodesByIP4[ip]; ok { - return n.mac, true + if node, ok := n.nodesByIP4[ip]; ok { + // Use the MAC for this specific network (important for multi-NIC nodes + // where the primary MAC may be on a different network). + return node.macForNet(n), true } return MAC{}, false } +// SetControlBlackholed sets whether traffic to control should be blackholed for the +// network. +func (n *network) SetControlBlackholed(v bool) { + n.blackholeControl = v +} + +// nodeNIC represents a single network interface on a node. +// For multi-homed nodes, additional NICs beyond the primary are stored in node.extraNICs. +type nodeNIC struct { + mac MAC + net *network + lanIP netip.Addr + interfaceID int +} + type node struct { mac MAC num int // 1-based node number @@ -586,6 +652,8 @@ type node struct { lanIP netip.Addr // must be in net.lanIP prefix + unique in net verboseSyslog bool + extraNICs []nodeNIC // secondary NICs for multi-homed nodes + // logMu guards logBuf. // TODO(bradfitz): conditionally write these out to separate files at the end? // Currently they only hold logcatcher logs. @@ -594,6 +662,35 @@ type node struct { logCatcherWrites int } +// netForMAC returns the network associated with the given MAC address on this node. +// It checks the primary NIC first, then any extra NICs. +func (n *node) netForMAC(mac MAC) *network { + if mac == n.mac { + return n.net + } + for _, nic := range n.extraNICs { + if nic.mac == mac { + return nic.net + } + } + return nil +} + +// macForNet returns the MAC address that this node uses on the given network. +// For the primary network, this is node.mac. For secondary networks, it's the +// extra NIC's MAC. +func (n *node) macForNet(net *network) MAC { + if n.net == net { + return n.mac + } + for _, nic := range n.extraNICs { + if nic.net == net { + return nic.mac + } + } + return n.mac // fallback to primary +} + // String returns the string "nodeN" where N is the 1-based node number. func (n *node) String() string { return fmt.Sprintf("node%d", n.num) @@ -650,6 +747,14 @@ type Server struct { agentConnWaiter map[*node]chan<- struct{} // signaled after added to set agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all agentDialer map[*node]netx.DialFunc + gotFirstPacket map[MAC]chan struct{} // closed on first packet from each MAC + + cloudInitData map[int]*CloudInitData // node num → cloud-init config + fileContents map[string][]byte // filename → file bytes + + // onDHCPEvent, if non-nil, is called when DHCP messages are processed. + // Parameters are: source MAC, node number, DHCP message type, assigned IP. + onDHCPEvent func(nodeMAC MAC, nodeNum int, msgType layers.DHCPMsgType, assignedIP netip.Addr) } func (s *Server) logf(format string, args ...any) { @@ -664,6 +769,13 @@ func (s *Server) SetLoggerForTest(logf func(format string, args ...any)) { s.optLogf = logf } +// SetDHCPCallback registers a function to be called when DHCP messages are +// processed. The callback receives the source MAC, node number, DHCP message +// type (Discover, Offer, Request, Ack), and the assigned IP address. +func (s *Server) SetDHCPCallback(fn func(MAC, int, layers.DHCPMsgType, netip.Addr)) { + s.onDHCPEvent = fn +} + var derpMap = &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{ 1: { @@ -725,6 +837,10 @@ func New(c *Config) (*Server, error) { if err := s.initFromConfig(c); err != nil { return nil, err } + s.gotFirstPacket = make(map[MAC]chan struct{}) + for mac := range s.nodeByMAC { + s.gotFirstPacket[mac] = make(chan struct{}) + } for n := range s.networks { if err := n.initStack(); err != nil { return nil, fmt.Errorf("newServer: initStack: %v", err) @@ -734,6 +850,96 @@ func New(c *Config) (*Server, error) { return s, nil } +// ControlServer returns the test control server used by this vnet. +func (s *Server) ControlServer() *testcontrol.Server { + return s.control +} + +// CloudInitData holds the cloud-init configuration for a node. +type CloudInitData struct { + MetaData string + UserData string + NetworkConfig string // optional; if set, served as network-config +} + +// SetCloudInitData registers cloud-init configuration for the given node number. +// This data is served via the cloud-init.tailscale VIP when the VM boots. +func (s *Server) SetCloudInitData(nodeNum int, data *CloudInitData) { + s.mu.Lock() + defer s.mu.Unlock() + mak.Set(&s.cloudInitData, nodeNum, data) +} + +// RegisterFile registers a file to be served by the files.tailscale VIP. +// The path is the URL path (e.g., "tta" is served at http://files.tailscale/tta). +func (s *Server) RegisterFile(path string, data []byte) { + s.mu.Lock() + defer s.mu.Unlock() + mak.Set(&s.fileContents, path, data) +} + +// cloudInitHandler returns an HTTP handler that serves cloud-init +// meta-data and user-data for VMs that boot with +// ds=nocloud;s=http://cloud-init.tailscale/node-N/. +func (s *Server) cloudInitHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Parse node number from URL path like "/node-2/meta-data" + path := strings.TrimPrefix(r.URL.Path, "/") + parts := strings.SplitN(path, "/", 2) + if len(parts) != 2 { + http.Error(w, "bad path", http.StatusNotFound) + return + } + nodeNum := 0 + if _, err := fmt.Sscanf(parts[0], "node-%d", &nodeNum); err != nil { + http.Error(w, "bad node number", http.StatusNotFound) + return + } + s.mu.Lock() + data := s.cloudInitData[nodeNum] + s.mu.Unlock() + if data == nil { + http.Error(w, "no cloud-init data for node", http.StatusNotFound) + return + } + switch parts[1] { + case "meta-data": + w.Header().Set("Content-Type", "text/yaml") + io.WriteString(w, data.MetaData) + case "user-data": + w.Header().Set("Content-Type", "text/yaml") + io.WriteString(w, data.UserData) + case "network-config": + if data.NetworkConfig == "" { + http.Error(w, "not found", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "text/yaml") + io.WriteString(w, data.NetworkConfig) + default: + http.Error(w, "not found", http.StatusNotFound) + } + }) +} + +// fileServerHandler returns an HTTP handler that serves files registered +// via RegisterFile. Files are served at http://files.tailscale/. +func (s *Server) fileServerHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/") + s.mu.Lock() + data, ok := s.fileContents[path] + s.mu.Unlock() + if !ok { + http.Error(w, "not found", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Length", strconv.Itoa(len(data))) + w.Write(data) + }) +} + func (s *Server) Close() { if shutdown := s.shuttingDown.Swap(true); !shutdown { s.shutdownCancel() @@ -742,6 +948,22 @@ func (s *Server) Close() { s.wg.Wait() } +// AwaitFirstPacket waits until the first ethernet frame is received from the +// given MAC address, indicating the VM has booted far enough to send network +// traffic. It returns an error if the context expires first. +func (s *Server) AwaitFirstPacket(ctx context.Context, mac MAC) error { + ch, ok := s.gotFirstPacket[mac] + if !ok { + return fmt.Errorf("unknown MAC %v", mac) + } + select { + case <-ch: + return nil + case <-ctx.Done(): + return fmt.Errorf("no network packets received from %v: %w", mac, ctx.Err()) + } +} + // MACs returns the MAC addresses of the configured nodes. func (s *Server) MACs() iter.Seq[MAC] { return maps.Keys(s.nodeByMAC) @@ -855,8 +1077,7 @@ func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) { n, addr, err := uc.ReadFromUnix(buf) raddr = addr if err != nil { - if s.shutdownCtx.Err() != nil { - // Return without logging. + if s.shutdownCtx.Err() != nil || errors.Is(err, net.ErrClosed) { return } s.logf("ReadFromUnix: %#v", err) @@ -898,9 +1119,21 @@ func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) { } if !didReg[srcMAC] { didReg[srcMAC] = true + if ch, ok := s.gotFirstPacket[srcMAC]; ok { + select { + case <-ch: // already closed + default: + close(ch) + } + } + srcNet := srcNode.netForMAC(srcMAC) + if srcNet == nil { + s.logf("[conn %p] node %v has no network for MAC %v", c.uc, srcNode, srcMAC) + continue + } s.logf("[conn %p] Registering writer for MAC %v, node %v", c.uc, srcMAC, srcNode.lanIP) - srcNode.net.registerWriter(srcMAC, c) - defer srcNode.net.unregisterWriter(srcMAC) + srcNet.registerWriter(srcMAC, c) + defer srcNet.unregisterWriter(srcMAC) } if err := s.handleEthernetFrameFromVM(packetRaw); err != nil { @@ -923,16 +1156,38 @@ func (s *Server) handleEthernetFrameFromVM(packetRaw []byte) error { return fmt.Errorf("got frame from unknown MAC %v", srcMAC) } + srcNet := srcNode.netForMAC(srcMAC) + if srcNet == nil { + return fmt.Errorf("node %v has no network for MAC %v", srcNode, srcMAC) + } + must.Do(s.pcapWriter.WritePacket(gopacket.CaptureInfo{ Timestamp: time.Now(), CaptureLength: len(packetRaw), Length: len(packetRaw), InterfaceIndex: srcNode.interfaceID, }, packetRaw)) - srcNode.net.HandleEthernetPacket(ep) + srcNet.HandleEthernetPacket(ep) return nil } +// routeTCPPacket forwards a TCP packet to the network owning the +// destination IP (looked up by WAN IP). Used for inter-network TCP +// forwarding so guest VM TCP stacks talk end-to-end through vnet's +// packet-level NAT. +func (s *Server) routeTCPPacket(tp TCPPacket) { + dstIP := tp.Dst.Addr() + netw, ok := s.networkByWAN.Lookup(dstIP) + if !ok { + if dstIP.IsPrivate() { + return + } + log.Printf("no network to route TCP packet for %v", tp.Dst) + return + } + netw.HandleTCPPacket(tp) +} + func (s *Server) routeUDPPacket(up UDPPacket) { // Find which network owns this based on the destination IP // and all the known networks' wan IPs. @@ -1169,6 +1424,65 @@ func (n *network) nodeByIP(ip netip.Addr) (node *node, ok bool) { return node, ok } +// HandleTCPPacket handles a TCP packet arriving from the simulated +// internet, addressed to the network's WAN IP. It NATs the destination +// back to a LAN node and writes the rewritten packet onto the LAN. +func (n *network) HandleTCPPacket(p TCPPacket) { + buf, err := n.serializedTCPPacket(p.Src, p.Dst, p.TCP, nil) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(buf), + Length: len(buf), + InterfaceIndex: n.wanInterfaceID, + }, buf) + if p.Dst.Addr().Is4() && n.breakWAN4 { + return + } + dst := n.doNATIn(p.Src, p.Dst) + if !dst.IsValid() { + n.logf("Warning: NAT dropped TCP packet; no mapping for %v=>%v", p.Src, p.Dst) + return + } + p.Dst = dst + buf, err = n.serializedTCPPacket(p.Src, p.Dst, p.TCP, nil) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(buf), + Length: len(buf), + InterfaceIndex: n.lanInterfaceID, + }, buf) + n.WriteTCPPacketNoNAT(p) +} + +// WriteTCPPacketNoNAT writes a TCP packet to the network without doing +// any NAT translation. The src/dst in p must already be in their final +// form for the LAN. +func (n *network) WriteTCPPacketNoNAT(p TCPPacket) { + node, ok := n.nodeByIP(p.Dst.Addr()) + if !ok { + n.logf("no node for dest IP %v in TCP packet %v=>%v", p.Dst.Addr(), p.Src, p.Dst) + return + } + eth := &layers.Ethernet{ + SrcMAC: n.mac.HWAddr(), + DstMAC: node.macForNet(n).HWAddr(), + } + ethRaw, err := n.serializedTCPPacket(p.Src, p.Dst, p.TCP, eth) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.writeEth(ethRaw) +} + // WriteUDPPacketNoNAT writes a UDP packet to the network, without // doing any NAT translation. // @@ -1184,8 +1498,8 @@ func (n *network) WriteUDPPacketNoNAT(p UDPPacket) { } eth := &layers.Ethernet{ - SrcMAC: n.mac.HWAddr(), // of gateway - DstMAC: node.mac.HWAddr(), + SrcMAC: n.mac.HWAddr(), // of gateway; on the specific network + DstMAC: node.macForNet(n).HWAddr(), // use the MAC for this network } ethRaw, err := n.serializedUDPPacket(src, dst, p.Payload, eth) if err != nil { @@ -1218,6 +1532,27 @@ func mkIPLayer(proto layers.IPProtocol, src, dst netip.Addr) serializableNetwork panic("invalid src IP") } +// serializedTCPPacket serializes a TCP packet with the given src/dst, +// using the provided TCP layer (its flags, seq/ack, window, options, +// and payload are preserved; only the src/dst ports are overwritten). +// +// If eth is non-nil, it is used as the Ethernet layer, otherwise the +// Ethernet layer is omitted. +func (n *network) serializedTCPPacket(src, dst netip.AddrPort, tcp *layers.TCP, eth *layers.Ethernet) ([]byte, error) { + ip := mkIPLayer(layers.IPProtocolTCP, src.Addr(), dst.Addr()) + // Copy the TCP layer with new ports and a zeroed checksum so + // gopacket recomputes it against the new IP pseudo-header. + newTCP := *tcp + newTCP.SrcPort = layers.TCPPort(src.Port()) + newTCP.DstPort = layers.TCPPort(dst.Port()) + newTCP.Checksum = 0 + payload := gopacket.Payload(tcp.Payload) + if eth == nil { + return mkPacket(ip, &newTCP, payload) + } + return mkPacket(eth, ip, &newTCP, payload) +} + // serializedUDPPacket serializes a UDP packet with the given source and // destination IP:port pairs, and payload. // @@ -1263,7 +1598,8 @@ func (n *network) HandleEthernetPacketForRouter(ep EthernetPacket) { } if toForward && n.s.shouldInterceptTCP(packet) { - if flow.dst.Is4() && n.breakWAN4 { + if (flow.dst.Is4() && n.breakWAN4) || + (n.blackholeControl && fakeControl.Match(flow.dst)) { // Blackhole the packet. return } @@ -1288,14 +1624,81 @@ func (n *network) HandleEthernetPacketForRouter(ep EthernetPacket) { return } + // Inter-network TCP forwarding: a guest VM is sending TCP to another + // simulated network's WAN IP. Apply egress NAT (rewriting src) and + // hand the packet off to the destination network for ingress NAT and + // LAN delivery, so the two guest TCP stacks talk end-to-end. + if toForward && flow.dst.Is4() { + if tcp, ok := packet.Layer(layers.LayerTypeTCP).(*layers.TCP); ok { + if _, ok := n.s.networkByWAN.Lookup(flow.dst); ok { + n.handleTCPPacketForRouter(tcp, flow) + return + } + } + } + if flow.src.Is6() && flow.src.IsLinkLocalUnicast() && !flow.dst.IsLinkLocalUnicast() { // Don't log. return } + if toForward { + // Traffic to destinations we don't handle (e.g. VMs trying to reach + // the real internet for NTP, package updates, etc). Expected; drop silently. + return + } + n.logf("router got unknown packet: %v", packet) } +// handleTCPPacketForRouter handles a TCP packet from a LAN node that +// targets another simulated network's WAN IP. It rewrites src via the +// local NAT, then routes the packet to the destination network where +// HandleTCPPacket rewrites dst and delivers it to the LAN. +func (n *network) handleTCPPacketForRouter(tcp *layers.TCP, flow ipSrcDst) { + if flow.dst.Is4() && n.breakWAN4 { + return + } + src := netip.AddrPortFrom(flow.src, uint16(tcp.SrcPort)) + dst := netip.AddrPortFrom(flow.dst, uint16(tcp.DstPort)) + + buf, err := n.serializedTCPPacket(src, dst, tcp, nil) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(buf), + Length: len(buf), + InterfaceIndex: n.lanInterfaceID, + }, buf) + + lanSrc := src + src = n.doNATOut(src, dst) + if !src.IsValid() { + n.logf("warning: NAT dropped TCP packet; no NAT out mapping for %v=>%v", lanSrc, dst) + return + } + buf, err = n.serializedTCPPacket(src, dst, tcp, nil) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(buf), + Length: len(buf), + InterfaceIndex: n.wanInterfaceID, + }, buf) + + n.s.routeTCPPacket(TCPPacket{ + Src: src, + Dst: dst, + TCP: tcp, + }) +} + func (n *network) handleUDPPacketForRouter(ep EthernetPacket, udp *layers.UDP, toForward bool, flow ipSrcDst) { packet := ep.gp srcIP, dstIP := flow.src, flow.dst @@ -1523,12 +1926,28 @@ func (s *Server) createDHCPResponse(request gopacket.Packet) ([]byte, error) { log.Printf("DHCP request from unknown node %v; ignoring", srcMAC) return nil, nil } - gwIP := node.net.lanIP4.Addr() + // Use the network associated with this MAC (important for multi-NIC nodes). + srcNet := node.netForMAC(srcMAC) + if srcNet == nil { + log.Printf("DHCP request from MAC %v with no associated network; ignoring", srcMAC) + return nil, nil + } + gwIP := srcNet.lanIP4.Addr() - ipLayer := request.Layer(layers.LayerTypeIPv4).(*layers.IPv4) udpLayer := request.Layer(layers.LayerTypeUDP).(*layers.UDP) dhcpLayer := request.Layer(layers.LayerTypeDHCPv4).(*layers.DHCPv4) + // Determine the client's LAN IP for this specific NIC. + clientIP := node.lanIP + if srcMAC != node.mac { + for _, nic := range node.extraNICs { + if nic.mac == srcMAC { + clientIP = nic.lanIP + break + } + } + } + response := &layers.DHCPv4{ Operation: layers.DHCPOpReply, HardwareType: layers.LinkTypeEthernet, @@ -1536,7 +1955,7 @@ func (s *Server) createDHCPResponse(request gopacket.Packet) ([]byte, error) { Xid: dhcpLayer.Xid, ClientHWAddr: dhcpLayer.ClientHWAddr, Flags: dhcpLayer.Flags, - YourClientIP: node.lanIP.AsSlice(), + YourClientIP: clientIP.AsSlice(), Options: []layers.DHCPOption{ { Type: layers.DHCPOptServerID, @@ -1554,11 +1973,37 @@ func (s *Server) createDHCPResponse(request gopacket.Packet) ([]byte, error) { } switch msgType { case layers.DHCPMsgTypeDiscover: - response.Options = append(response.Options, layers.DHCPOption{ - Type: layers.DHCPOptMessageType, - Data: []byte{byte(layers.DHCPMsgTypeOffer)}, - Length: 1, - }) + response.Options = append(response.Options, + layers.DHCPOption{ + Type: layers.DHCPOptMessageType, + Data: []byte{byte(layers.DHCPMsgTypeOffer)}, + Length: 1, + }, + layers.DHCPOption{ + Type: layers.DHCPOptLeaseTime, + Data: binary.BigEndian.AppendUint32(nil, 3600), + Length: 4, + }, + layers.DHCPOption{ + Type: layers.DHCPOptSubnetMask, + Data: net.CIDRMask(srcNet.lanIP4.Bits(), 32), + Length: 4, + }, + layers.DHCPOption{ + Type: layers.DHCPOptRouter, + Data: gwIP.AsSlice(), + Length: 4, + }, + layers.DHCPOption{ + Type: layers.DHCPOptDNS, + Data: fakeDNS.v4.AsSlice(), + Length: 4, + }, + ) + if s.onDHCPEvent != nil { + s.onDHCPEvent(srcMAC, node.num, layers.DHCPMsgTypeDiscover, clientIP) + s.onDHCPEvent(srcMAC, node.num, layers.DHCPMsgTypeOffer, clientIP) + } case layers.DHCPMsgTypeRequest: response.Options = append(response.Options, layers.DHCPOption{ @@ -1583,10 +2028,14 @@ func (s *Server) createDHCPResponse(request gopacket.Packet) ([]byte, error) { }, layers.DHCPOption{ Type: layers.DHCPOptSubnetMask, - Data: net.CIDRMask(node.net.lanIP4.Bits(), 32), + Data: net.CIDRMask(srcNet.lanIP4.Bits(), 32), Length: 4, }, ) + if s.onDHCPEvent != nil { + s.onDHCPEvent(srcMAC, node.num, layers.DHCPMsgTypeRequest, clientIP) + s.onDHCPEvent(srcMAC, node.num, layers.DHCPMsgTypeAck, clientIP) + } } eth := &layers.Ethernet{ @@ -1596,8 +2045,8 @@ func (s *Server) createDHCPResponse(request gopacket.Packet) ([]byte, error) { } ip := &layers.IPv4{ Protocol: layers.IPProtocolUDP, - SrcIP: ipLayer.DstIP, - DstIP: ipLayer.SrcIP, + SrcIP: gwIP.AsSlice(), + DstIP: net.IPv4bcast, // DHCP responses are broadcast when client has no IP yet } udp := &layers.UDP{ SrcPort: udpLayer.DstPort, @@ -1645,7 +2094,7 @@ func (s *Server) shouldInterceptTCP(pkt gopacket.Packet) bool { } if tcp.DstPort == 80 || tcp.DstPort == 443 { - for _, v := range []virtualIP{fakeControl, fakeDERP1, fakeDERP2, fakeLogCatcher} { + for _, v := range []virtualIP{fakeControl, fakeDERP1, fakeDERP2, fakeLogCatcher, fakeCloudInit, fakeFiles} { if v.Match(flow.dst) { return true } @@ -1657,7 +2106,7 @@ func (s *Server) shouldInterceptTCP(pkt gopacket.Packet) bool { return true } } - if tcp.DstPort == 8008 && fakeTestAgent.Match(flow.dst) { + if tcp.DstPort == TestDriverPort && fakeTestAgent.Match(flow.dst) { // Connection from cmd/tta. return true } @@ -1909,7 +2358,7 @@ func (n *network) doPortMap(src netip.Addr, dstLANPort, wantExtPort uint16, sec } } - for try := 0; try < 20_000; try++ { + for range 20_000 { if wanAP.Port() > 0 && !n.natTable.IsPublicPortUsed(wanAP) { mak.Set(&n.portMap, wanAP, portMapping{ dst: dst, @@ -2047,6 +2496,17 @@ type UDPPacket struct { Payload []byte // everything after UDP header } +// TCPPacket is a TCP packet flowing through vnet's NAT, used for +// packet-level TCP forwarding between simulated networks. Unlike UDP +// (which only needs ports + payload), TCP carries flags, sequence +// numbers, and options that must be preserved end-to-end so the guest +// VM kernels' TCP state machines stay in sync. +type TCPPacket struct { + Src netip.AddrPort + Dst netip.AddrPort + TCP *layers.TCP // full parsed TCP layer (header + options + payload) +} + func (s *Server) WriteStartingBanner(w io.Writer) { fmt.Fprintf(w, "vnet serving clients:\n") diff --git a/tstest/natlab/vnet/vnet_test.go b/tstest/natlab/vnet/vnet_test.go index 93f208c29ca0a..9d7c78c453b11 100644 --- a/tstest/natlab/vnet/vnet_test.go +++ b/tstest/natlab/vnet/vnet_test.go @@ -120,7 +120,7 @@ func TestPacketSideEffects(t *testing.T) { check: all( numPkts(2), // DHCP discover broadcast to node2 also, and the DHCP reply from router pktSubstr("SrcMAC=52:cc:cc:cc:cc:01 DstMAC=ff:ff:ff:ff:ff:ff"), - pktSubstr("Options=[Option(ServerID:192.168.0.1), Option(MessageType:Offer)]}"), + pktSubstr("Option(ServerID:192.168.0.1), Option(MessageType:Offer), Option(LeaseTime:3600)"), ), }, { diff --git a/tstest/reflect.go b/tstest/reflect.go index 22903e7e9fca2..4ba1f96c39666 100644 --- a/tstest/reflect.go +++ b/tstest/reflect.go @@ -8,8 +8,6 @@ import ( "reflect" "testing" "time" - - "tailscale.com/types/ptr" ) // IsZeroable is the interface for things with an IsZero method. @@ -60,7 +58,7 @@ func CheckIsZero[T IsZeroable](t testing.TB, nonzeroValues map[reflect.Type]any) case timeType: return reflect.ValueOf(time.Unix(1704067200, 0)) case timePtrType: - return reflect.ValueOf(ptr.To(time.Unix(1704067200, 0))) + return reflect.ValueOf(new(time.Unix(1704067200, 0))) } switch ty.Kind() { diff --git a/tstest/resource_test.go b/tstest/resource_test.go index ecef91cf60b08..4c3b68eeb4d80 100644 --- a/tstest/resource_test.go +++ b/tstest/resource_test.go @@ -22,7 +22,7 @@ func TestPrintGoroutines(t *testing.T) { want: "goroutine profile: total 0", }, { - name: "single goroutine", + name: "single-goroutine", in: `goroutine profile: total 1 1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 @@ -34,7 +34,7 @@ func TestPrintGoroutines(t *testing.T) { `, }, { - name: "multiple goroutines sorted", + name: "multiple-goroutines-sorted", in: `goroutine profile: total 14 7 @ 0x47bc0e 0x413705 0x4132b2 0x10fda4d 0x483da1 # 0x10fda4c github.com/user/pkg.RoutineA+0x16c pkg/a.go:443 @@ -70,7 +70,7 @@ func TestDiffPprofGoroutines(t *testing.T) { want string }{ { - name: "no difference", + name: "no-difference", x: `goroutine profile: total 1 1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261`, @@ -81,7 +81,7 @@ func TestDiffPprofGoroutines(t *testing.T) { want: "", }, { - name: "different counts", + name: "different-counts", x: `goroutine profile: total 1 1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 @@ -99,7 +99,7 @@ func TestDiffPprofGoroutines(t *testing.T) { `, }, { - name: "new goroutine", + name: "new-goroutine", x: `goroutine profile: total 1 1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 @@ -119,7 +119,7 @@ func TestDiffPprofGoroutines(t *testing.T) { `, }, { - name: "removed goroutine", + name: "removed-goroutine", x: `goroutine profile: total 2 1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 @@ -139,7 +139,7 @@ func TestDiffPprofGoroutines(t *testing.T) { `, }, { - name: "removed many goroutine", + name: "removed-many-goroutine", x: `goroutine profile: total 2 1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 @@ -159,13 +159,13 @@ func TestDiffPprofGoroutines(t *testing.T) { `, }, { - name: "invalid input x", + name: "invalid-input-x", x: "invalid", y: "goroutine profile: total 0\n", want: "- invalid\n+ goroutine profile: total 0\n", }, { - name: "invalid input y", + name: "invalid-input-y", x: "goroutine profile: total 0\n", y: "invalid", want: "- goroutine profile: total 0\n+ invalid\n", @@ -193,13 +193,13 @@ func TestParseGoroutines(t *testing.T) { wantCount int }{ { - name: "empty profile", + name: "empty-profile", in: "goroutine profile: total 0\n", wantHeader: "goroutine profile: total 0", wantCount: 0, }, { - name: "single goroutine", + name: "single-goroutine", in: `goroutine profile: total 1 1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 @@ -208,7 +208,7 @@ func TestParseGoroutines(t *testing.T) { wantCount: 1, }, { - name: "multiple goroutines", + name: "multiple-goroutines", in: `goroutine profile: total 14 7 @ 0x47bc0e 0x413705 0x4132b2 0x10fda4d 0x483da1 # 0x10fda4c github.com/user/pkg.RoutineA+0x16c pkg/a.go:443 @@ -220,7 +220,7 @@ func TestParseGoroutines(t *testing.T) { wantCount: 2, }, { - name: "invalid format", + name: "invalid-format", in: "invalid", wantHeader: "invalid", }, @@ -245,7 +245,7 @@ func TestParseGoroutines(t *testing.T) { t.Errorf("sort field has different number of words: got %d, want %d", len(sorted), len(original)) continue } - for i := 0; i < len(original); i++ { + for i := range original { if original[i] != sorted[len(sorted)-1-i] { t.Errorf("sort field word mismatch at position %d: got %q, want %q", i, sorted[len(sorted)-1-i], original[i]) } diff --git a/tstest/tailmac/Makefile b/tstest/tailmac/Makefile index b87e44ed1c49d..303f72c1f5d26 100644 --- a/tstest/tailmac/Makefile +++ b/tstest/tailmac/Makefile @@ -5,12 +5,12 @@ endif .PHONY: tailmac tailmac: - xcodebuild -scheme tailmac -destination 'platform=macOS,arch=arm64' -derivedDataPath build -configuration Release build | $(XCPRETTIFIER) + set -o pipefail && xcodebuild -scheme tailmac -destination 'platform=macOS,arch=arm64' -derivedDataPath build -configuration Release build | $(XCPRETTIFIER) cp -r ./build/Build/Products/Release/tailmac ./bin/tailmac .PHONY: host host: - xcodebuild -scheme host -destination 'platform=macOS,arch=arm64' -derivedDataPath build -configuration Release build | $(XCPRETTIFIER) + set -o pipefail && xcodebuild -scheme host -destination 'platform=macOS,arch=arm64' -derivedDataPath build -configuration Release build | $(XCPRETTIFIER) cp -r ./build/Build/Products/Release/Host.app ./bin/Host.app .PHONY: clean diff --git a/tstest/tailmac/README.md b/tstest/tailmac/README.md index a8b9f2598dde3..6c62d24318119 100644 --- a/tstest/tailmac/README.md +++ b/tstest/tailmac/README.md @@ -53,7 +53,7 @@ All vm images, restore images, block device files, save states, and other suppor Each vm gets its own directory. These can be archived for posterity to preserve a particular image and/or state. The mere existence of a directory containing all of the required files in ~/VM.bundle is sufficient for tailmac to -be able to see and run it. ~/VM.bundle and it's contents *is* tailmac's state. No other state is maintained elsewhere. +be able to see and run it. ~/VM.bundle and its contents *is* tailmac's state. No other state is maintained elsewhere. Each vm has its own custom configuration which can be modified while the vm is idle. It's simple JSON - you may modify this directly, or using 'tailmac configure'. diff --git a/tstest/tailmac/Swift/Common/Config.swift b/tstest/tailmac/Swift/Common/Config.swift index 53d7680205a00..53281628a5826 100644 --- a/tstest/tailmac/Swift/Common/Config.swift +++ b/tstest/tailmac/Swift/Common/Config.swift @@ -103,10 +103,10 @@ class Config: Codable { } -// The VM Bundle URL holds the restore image and a set of VM images -// By default, VM's are persisted at ~/VM.bundle +// The VM Bundle URL holds the restore image and a set of VM images. +// VMs are stored under ~/.cache/tailscale/vmtest/macos/. var vmBundleURL: URL = { - let vmBundlePath = NSHomeDirectory() + "/VM.bundle/" + let vmBundlePath = NSHomeDirectory() + "/.cache/tailscale/vmtest/macos/" createDir(vmBundlePath) let bundleURL = URL(fileURLWithPath: vmBundlePath) return bundleURL diff --git a/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift b/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift index fc7f2d89dc0e2..562eae1fa1e2a 100644 --- a/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift +++ b/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift @@ -74,18 +74,31 @@ struct TailMacConfigHelper { return networkDevice } + /// Creates a NIC configuration connected to the vnet dgram socket. func createSocketNetworkDeviceConfiguration() -> VZVirtioNetworkDeviceConfiguration { let networkDevice = VZVirtioNetworkDeviceConfiguration() networkDevice.macAddress = VZMACAddress(string: config.mac)! + if let attachment = createDgramAttachment(serverSocket: config.serverSocket, clientID: config.vmID) { + networkDevice.attachment = attachment + } + return networkDevice + } - let socket = Darwin.socket(AF_UNIX, SOCK_DGRAM, 0) + /// Creates a NIC configuration with no attachment (disconnected). + /// The attachment can be hot-swapped later via VZNetworkDevice.attachment. + func createDisconnectedNetworkDeviceConfiguration() -> VZVirtioNetworkDeviceConfiguration { + let networkDevice = VZVirtioNetworkDeviceConfiguration() + networkDevice.macAddress = VZMACAddress(string: config.mac)! + // No attachment — NIC appears disconnected to the guest. + return networkDevice + } - // Outbound network packets - let serverSocket = config.serverSocket + /// Creates a dgram socket attachment for connecting to a vnet server. + /// Returns nil on error. + func createDgramAttachment(serverSocket: String, clientID: String) -> VZFileHandleNetworkDeviceAttachment? { + let socket = Darwin.socket(AF_UNIX, SOCK_DGRAM, 0) - // Inbound network packets - let clientSockId = config.vmID - let clientSocket = "/tmp/qemu-dgram-\(clientSockId).sock" + let clientSocket = "/tmp/qemu-dgram-\(clientID).sock" unlink(clientSocket) var clientAddr = sockaddr_un() @@ -102,7 +115,7 @@ struct TailMacConfigHelper { if bindRes == -1 { print("Error binding virtual network client socket - \(String(cString: strerror(errno)))") - return networkDevice + return nil } var serverAddr = sockaddr_un() @@ -118,20 +131,16 @@ struct TailMacConfigHelper { socklen_t(MemoryLayout.size)) if connectRes == -1 { - print("Error binding virtual network server socket - \(String(cString: strerror(errno)))") - return networkDevice + print("Error connecting to server socket \(serverSocket) - \(String(cString: strerror(errno)))") + return nil } print("Virtual if mac address is \(config.mac)") print("Client bound to \(clientSocket)") print("Connected to server at \(serverSocket)") - print("Socket fd is \(socket)") - let handle = FileHandle(fileDescriptor: socket) - let device = VZFileHandleNetworkDeviceAttachment(fileHandle: handle) - networkDevice.attachment = device - return networkDevice + return VZFileHandleNetworkDeviceAttachment(fileHandle: handle) } func createPointingDeviceConfiguration() -> VZPointingDeviceConfiguration { diff --git a/tstest/tailmac/Swift/Host/HostCli.swift b/tstest/tailmac/Swift/Host/HostCli.swift index 9c9ae6fa0476e..16711b2aab242 100644 --- a/tstest/tailmac/Swift/Host/HostCli.swift +++ b/tstest/tailmac/Swift/Host/HostCli.swift @@ -20,13 +20,291 @@ extension HostCli { struct Run: ParsableCommand { @Option var id: String @Option var share: String? + @Flag(help: "Run without GUI (for automated testing)") var headless: Bool = false + @Flag(help: "Create NIC with no attachment (for later hot-swap)") var disconnectedNic: Bool = false + @Flag(help: "Use NAT NIC instead of socket NIC (for snapshot prep)") var natNic: Bool = false + @Option(help: "Hot-swap NIC to this dgram socket path after boot/restore") var attachNetwork: String? + @Option(help: "Serve screenshots on this localhost port (0 = auto)") var screenshotPort: Int? + @Option(help: "Assign IP/mask/gw to guest via vsock (e.g. 192.168.1.2/255.255.255.0/192.168.1.1)") var assignIp: String? mutating func run() { config = Config(id) config.sharedDir = share print("Running vm with identifier \(id) and sharedDir \(share ?? "")") - _ = NSApplicationMain(CommandLine.argc, CommandLine.unsafeArgv) + + if headless { + let attachSocket = attachNetwork + let useNatNIC = natNic + let disconnected = !useNatNIC && (disconnectedNic || attachSocket != nil) + let wantScreenshots = screenshotPort != nil + let requestedPort = UInt16(screenshotPort ?? 0) + let ipConfig = assignIp + + // Set up SIGINT handler before entering the event loop. + // The dispatch source must be stored in a global to prevent ARC deallocation. + signal(SIGINT, SIG_IGN) + let sigintSource = DispatchSource.makeSignalSource(signal: SIGINT, queue: .main) + retainedSigintSource = sigintSource + + DispatchQueue.main.async { + let controller = VMController() + controller.createVirtualMachine(headless: true, disconnectedNIC: disconnected, natNIC: useNatNIC) + + // Start vsock listener for IP assignment. + // If --assign-ip is set, the listener replies with the IP config JSON. + // If not set (snapshot prep), it replies "wait" so TTA keeps polling. + if let ipCfg = ipConfig { + let parts = ipCfg.split(separator: "/") + if parts.count == 3 { + let response = "{\"ip\":\"\(parts[0])\",\"mask\":\"\(parts[1])\",\"gw\":\"\(parts[2])\"}" + controller.startIPConfigListener(response: response) + } + } else { + controller.startIPConfigListener(response: "wait") + } + + sigintSource.setEventHandler { + print("SIGINT received, disconnecting NIC and saving VM state...") + controller.disconnectNetwork() + controller.pauseAndSaveVirtualMachine { + print("VM state saved, exiting.") + Foundation.exit(0) + } + } + sigintSource.resume() + + // Set up screenshot HTTP server if requested. + // The window must be ordered on-screen for the window server + // to composite VZVirtualMachineView's content. We place it + // behind all other windows and make it tiny (1x1) so it's + // effectively invisible. + if wantScreenshots { + let vmView = VZVirtualMachineView() + vmView.virtualMachine = controller.virtualMachine + vmView.frame = NSRect(x: 0, y: 0, width: 1920, height: 1200) + + let window = NSWindow( + contentRect: NSRect(x: 0, y: 0, width: 1920, height: 1200), + styleMask: [.borderless], + backing: .buffered, + defer: false + ) + window.isReleasedWhenClosed = false + window.contentView = vmView + // Place behind all other windows so it's not visible to the user. + window.level = NSWindow.Level(rawValue: Int(CGWindowLevelForKey(.minimumWindow)) - 1) + window.orderFront(nil) + + startScreenshotServer(view: vmView, port: requestedPort) + } + + let doAttach = { + if let sock = attachSocket { + controller.attachNetwork(serverSocket: sock, clientID: config.vmID) + } + } + + let fileManager = FileManager.default + if fileManager.fileExists(atPath: config.saveFileURL.path) { + print("Restoring virtual machine state from \(config.saveFileURL)") + controller.restoreVirtualMachine() + doAttach() + } else { + print("Starting virtual machine") + controller.startVirtualMachine() + doAttach() + } + } + + if wantScreenshots { + // NSApp event loop needed for VZVirtualMachineView rendering. + let app = NSApplication.shared + app.setActivationPolicy(.accessory) + print("STARTING_NSAPP") + fflush(stdout) + app.run() + } else { + // Use dispatchMain() instead of RunLoop.main.run() so that + // GCD dispatch sources (like the SIGINT handler) are processed. + dispatchMain() + } + } else { + _ = NSApplicationMain(CommandLine.argc, CommandLine.unsafeArgv) + } + } + } +} + +// startScreenshotServer starts a localhost HTTP server that serves VM display +// screenshots on GET /screenshot as JPEG. The port is printed to stdout as +// "SCREENSHOT_PORT=" so the Go test harness can discover it. +var retainedSigintSource: DispatchSourceSignal? // prevent ARC deallocation +var screenshotServer: ScreenshotHTTPServer? // prevent GC + +func startScreenshotServer(view: NSView, port: UInt16) { + let server = ScreenshotHTTPServer(view: view) + screenshotServer = server + server.start(port: port) +} + +/// Minimal HTTP server that serves screenshots of a VZVirtualMachineView. +class ScreenshotHTTPServer: NSObject { + let view: NSView + var acceptSource: DispatchSourceRead? // prevent GC + + init(view: NSView) { + self.view = view + } + + private func log(_ msg: String) { + let s = msg + "\n" + FileHandle.standardError.write(Data(s.utf8)) + } + + func start(port: UInt16) { + let queue = DispatchQueue(label: "screenshot-server") + + let fd = socket(AF_INET, SOCK_STREAM, 0) + guard fd >= 0 else { + log("screenshot server: socket() failed") + return + } + var yes: Int32 = 1 + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yes, socklen_t(MemoryLayout.size)) + + var addr = sockaddr_in() + addr.sin_len = UInt8(MemoryLayout.size) + addr.sin_family = sa_family_t(AF_INET) + addr.sin_port = port.bigEndian + addr.sin_addr.s_addr = UInt32(0x7f000001).bigEndian // 127.0.0.1 + + let bindResult = withUnsafePointer(to: &addr) { ptr in + ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockPtr in + Darwin.bind(fd, sockPtr, socklen_t(MemoryLayout.size)) + } + } + guard bindResult == 0 else { + log("screenshot server: bind() failed: \(errno)") + close(fd) + return } + guard Darwin.listen(fd, 4) == 0 else { + log("screenshot server: listen() failed") + close(fd) + return + } + + var boundAddr = sockaddr_in() + var boundLen = socklen_t(MemoryLayout.size) + withUnsafeMutablePointer(to: &boundAddr) { ptr in + ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockPtr in + getsockname(fd, sockPtr, &boundLen) + } + } + let actualPort = UInt16(bigEndian: boundAddr.sin_port) + print("SCREENSHOT_PORT=\(actualPort)") + fflush(stdout) + + let source = DispatchSource.makeReadSource(fileDescriptor: fd, queue: queue) + source.setEventHandler { [self] in + let clientFd = accept(fd, nil, nil) + self.log("screenshot: accept fd=\(clientFd)") + guard clientFd >= 0 else { return } + self.handleConnection(clientFd) + } + source.setCancelHandler { close(fd) } + source.resume() + self.acceptSource = source + } + + private func handleConnection(_ fd: Int32) { + var buf = [UInt8](repeating: 0, count: 4096) + let n = read(fd, &buf, buf.count) + let requestLine = n > 0 ? String(bytes: buf[.. Data? { + guard let window = view.window else { + log("screenshot: no window") + return nil + } + + // Use CGWindowListCreateImage to capture the composited window content, + // which includes GPU-rendered layers like VZVirtualMachineView's Metal surface. + let windowID = CGWindowID(window.windowNumber) + guard let cgImage = CGWindowListCreateImage( + .null, + .optionIncludingWindow, + windowID, + [.boundsIgnoreFraming, .bestResolution] + ) else { + log("screenshot: CGWindowListCreateImage returned nil") + return nil + } + + if fullSize { + let bitmapRep = NSBitmapImageRep(cgImage: cgImage) + return bitmapRep.representation(using: .jpeg, properties: [.compressionFactor: 0.85]) + } + + // Resize to ~800px wide for thumbnails. + let targetWidth = 800 + let scale = Double(targetWidth) / Double(cgImage.width) + let targetHeight = Int(Double(cgImage.height) * scale) + + guard let ctx = CGContext( + data: nil, + width: targetWidth, + height: targetHeight, + bitsPerComponent: 8, + bytesPerRow: 0, + space: CGColorSpaceCreateDeviceRGB(), + bitmapInfo: CGImageAlphaInfo.premultipliedFirst.rawValue + ) else { + log("screenshot: CGContext creation failed") + return nil + } + ctx.interpolationQuality = .high + ctx.draw(cgImage, in: CGRect(x: 0, y: 0, width: targetWidth, height: targetHeight)) + + guard let resized = ctx.makeImage() else { + log("screenshot: makeImage failed") + return nil + } + + let bitmapRep = NSBitmapImageRep(cgImage: resized) + return bitmapRep.representation(using: .jpeg, properties: [.compressionFactor: 0.6]) } } diff --git a/tstest/tailmac/Swift/Host/VMController.swift b/tstest/tailmac/Swift/Host/VMController.swift index a19d7222e1e9e..c2014009a8ee9 100644 --- a/tstest/tailmac/Swift/Host/VMController.swift +++ b/tstest/tailmac/Swift/Host/VMController.swift @@ -81,7 +81,7 @@ class VMController: NSObject, VZVirtualMachineDelegate { return macPlatform } - func createVirtualMachine() { + func createVirtualMachine(headless: Bool = false, disconnectedNIC: Bool = false, natNIC: Bool = false) { let virtualMachineConfiguration = VZVirtualMachineConfiguration() virtualMachineConfiguration.platform = createMacPlaform() @@ -90,7 +90,21 @@ class VMController: NSObject, VZVirtualMachineDelegate { virtualMachineConfiguration.memorySize = helper.computeMemorySize() virtualMachineConfiguration.graphicsDevices = [helper.createGraphicsDeviceConfiguration()] virtualMachineConfiguration.storageDevices = [helper.createBlockDeviceConfiguration()] - virtualMachineConfiguration.networkDevices = [helper.createNetworkDeviceConfiguration(), helper.createSocketNetworkDeviceConfiguration()] + if headless { + if natNIC { + // NAT NIC for SSH access during snapshot preparation. + virtualMachineConfiguration.networkDevices = [helper.createNetworkDeviceConfiguration()] + } else if disconnectedNIC { + // Create a NIC with no attachment. The NIC exists in the hardware + // config (so saved state is compatible) but appears disconnected. + // Call attachNetwork() after restore to hot-swap the attachment. + virtualMachineConfiguration.networkDevices = [helper.createDisconnectedNetworkDeviceConfiguration()] + } else { + virtualMachineConfiguration.networkDevices = [helper.createSocketNetworkDeviceConfiguration()] + } + } else { + virtualMachineConfiguration.networkDevices = [helper.createNetworkDeviceConfiguration(), helper.createSocketNetworkDeviceConfiguration()] + } virtualMachineConfiguration.pointingDevices = [helper.createPointingDeviceConfiguration()] virtualMachineConfiguration.keyboards = [helper.createKeyboardConfiguration()] virtualMachineConfiguration.socketDevices = [helper.createSocketDeviceConfiguration()] @@ -109,6 +123,33 @@ class VMController: NSObject, VZVirtualMachineDelegate { virtualMachine.delegate = self } + /// Disconnect the NIC by setting its attachment to nil. + /// Call before saving state so the snapshot has no active link. + func disconnectNetwork() { + guard let nic = virtualMachine.networkDevices.first else { + print("disconnectNetwork: no network devices") + return + } + nic.attachment = nil + print("disconnectNetwork: NIC attachment set to nil") + } + + /// Hot-swap the NIC attachment on a running VM. The VM must have been + /// created with disconnectedNIC=true. After calling this, the guest + /// sees the link come up and does DHCP. + func attachNetwork(serverSocket: String, clientID: String) { + guard let nic = virtualMachine.networkDevices.first else { + print("attachNetwork: no network devices") + return + } + guard let attachment = helper.createDgramAttachment(serverSocket: serverSocket, clientID: clientID) else { + print("attachNetwork: failed to create attachment") + return + } + nic.attachment = attachment + print("attachNetwork: NIC attachment swapped to \(serverSocket)") + } + func startVirtualMachine() { virtualMachine.start(completionHandler: { (result) in @@ -130,6 +171,21 @@ class VMController: NSObject, VZVirtualMachineDelegate { } } + /// Start a vsock listener that tells the guest TTA agent what IP to configure. + /// If response is nil, the listener replies "wait" (snapshot prep mode). + func startIPConfigListener(response: String) { + guard let device = virtualMachine.socketDevices.first as? VZVirtioSocketDevice else { + print("startIPConfigListener: no socket device") + return + } + let listener = IPConfigListener(response: response) + retainedIPConfigListener = listener + let vsockListener = VZVirtioSocketListener() + vsockListener.delegate = listener + device.setSocketListener(vsockListener, forPort: 51011) + print("startIPConfigListener: listening on vsock port 51011") + } + func resumeVirtualMachine() { virtualMachine.resume(completionHandler: { (result) in if case let .failure(error) = result { @@ -184,3 +240,28 @@ class VMController: NSObject, VZVirtualMachineDelegate { exit(0) } } + +// Global to prevent ARC deallocation of the vsock listener. +var retainedIPConfigListener: IPConfigListener? + +/// Listens on vsock port 51011 for TTA connections and replies with +/// an IP configuration JSON string (or "wait" during snapshot prep). +class IPConfigListener: NSObject, VZVirtioSocketListenerDelegate { + let response: String + + init(response: String) { + self.response = response + } + + func listener(_ listener: VZVirtioSocketListener, + shouldAcceptNewConnection connection: VZVirtioSocketConnection, + from socketDevice: VZVirtioSocketDevice) -> Bool { + let fd = connection.fileDescriptor + let data = Array((response + "\n").utf8) + data.withUnsafeBufferPointer { buf in + _ = write(fd, buf.baseAddress!, buf.count) + } + connection.close() + return true + } +} diff --git a/tstest/tailmac/Swift/TailMac/TailMac.swift b/tstest/tailmac/Swift/TailMac/TailMac.swift index 3859b9b0b0aeb..2271d3bb29186 100644 --- a/tstest/tailmac/Swift/TailMac/TailMac.swift +++ b/tstest/tailmac/Swift/TailMac/TailMac.swift @@ -329,7 +329,7 @@ extension Tailmac { } } - dispatchMain() + RunLoop.main.run() } } } diff --git a/tstest/tstest.go b/tstest/tstest.go index 4e00fbaa38ae8..7e25ce8a03f35 100644 --- a/tstest/tstest.go +++ b/tstest/tstest.go @@ -20,8 +20,22 @@ import ( "tailscale.com/util/cibuild" ) +// AssertNotParallel asserts that t has not been marked as parallel. +// It panics (via t.Setenv) if t.Parallel has already been called. +// +// Use this when a test modifies package-level globals or other shared +// state that would be unsafe to modify concurrently with other tests. +func AssertNotParallel(t testing.TB) { + t.Helper() + t.Setenv("ASSERT_NOT_PARALLEL_TEST", "1") // panics if t.Parallel was called +} + // Replace replaces the value of target with val. // The old value is restored when the test ends. +// +// When target is a package-level variable, the caller should also call +// [AssertNotParallel] to ensure the test is not running in parallel with +// other tests that may access the same variable. func Replace[T any](t testing.TB, target *T, val T) { t.Helper() if target == nil { @@ -95,6 +109,14 @@ func Parallel(t *testing.T) { } } +// RequireRoot skips the test if the current user is not root. +func RequireRoot(tb testing.TB) { + tb.Helper() + if os.Getuid() != 0 { + tb.Skip("skipping test; requires root") + } +} + // SkipOnKernelVersions skips the test if the current // kernel version is in the specified list. func SkipOnKernelVersions(t testing.TB, issue string, versions ...string) { diff --git a/tstest/typewalk/typewalk.go b/tstest/typewalk/typewalk.go index f989b4c180394..dea87a8e927fc 100644 --- a/tstest/typewalk/typewalk.go +++ b/tstest/typewalk/typewalk.go @@ -54,14 +54,13 @@ func MatchingPaths(rt reflect.Type, match func(reflect.Type) bool) iter.Seq[Path return } switch t.Kind() { - case reflect.Ptr, reflect.Slice, reflect.Array: + case reflect.Pointer, reflect.Slice, reflect.Array: walk(t.Elem(), func(root reflect.Value) reflect.Value { v := getV(root) return v.Elem() }) case reflect.Struct: - for i := range t.NumField() { - sf := t.Field(i) + for sf := range t.Fields() { fieldName := sf.Name if fieldName == "_" { continue diff --git a/tsweb/debug_test.go b/tsweb/debug_test.go index b46a3a3f37c32..79c686b6b95d9 100644 --- a/tsweb/debug_test.go +++ b/tsweb/debug_test.go @@ -8,7 +8,9 @@ import ( "io" "net/http" "net/http/httptest" + "net/netip" "runtime" + "slices" "strings" "testing" ) @@ -206,3 +208,82 @@ func ExampleDebugHandler_Section() { fmt.Fprintf(w, "%#v", r) }) } + +func TestParseTrustedCIDRs(t *testing.T) { + tests := []struct { + name string + raw string + want []netip.Prefix + }{ + { + name: "empty", + raw: "", + want: nil, + }, + { + name: "single_v4", + raw: "10.0.0.0/8", + want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + }, + { + name: "multiple", + raw: "10.0.0.0/8,172.16.0.0/12", + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("172.16.0.0/12"), + }, + }, + { + name: "spaces_trimmed", + raw: " 10.0.0.0/8 , 192.168.0.0/16 ", + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "ipv6", + raw: "fd00::/8", + want: []netip.Prefix{netip.MustParsePrefix("fd00::/8")}, + }, + { + name: "trailing_comma", + raw: "10.0.0.0/8,", + want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseTrustedCIDRs(tt.raw) + if !slices.Equal(got, tt.want) { + t.Fatalf("got %v, want %v", got, tt.want) + } + }) + } +} + +func TestAllowDebugAccessTrustedCIDRContains(t *testing.T) { + // Verify that parsed CIDRs correctly match/reject IPs. + cidrs := parseTrustedCIDRs("10.0.0.0/8,192.168.1.0/24,fd00::/8") + + tests := []struct { + ip string + want bool + }{ + {"10.1.2.3", true}, + {"10.255.255.255", true}, + {"192.168.1.50", true}, + {"192.168.2.1", false}, + {"172.16.0.1", false}, + {"8.8.8.8", false}, + {"fd00::1", true}, + {"fe80::1", false}, + } + for _, tt := range tests { + ip := netip.MustParseAddr(tt.ip) + if got := cidrsContain(cidrs, ip); got != tt.want { + t.Errorf("CIDRs contain %s = %v, want %v", tt.ip, got, tt.want) + } + } +} diff --git a/tsweb/tsweb.go b/tsweb/tsweb.go index f464e7af2141e..101512b89b7f8 100644 --- a/tsweb/tsweb.go +++ b/tsweb/tsweb.go @@ -13,6 +13,8 @@ import ( "expvar" "fmt" "io" + "log" + "maps" "net" "net/http" "net/netip" @@ -53,6 +55,50 @@ func IsProd443(addr string) bool { return port == "443" || port == "https" } +// debugTrustedCIDRs is the envknob for TS_DEBUG_TRUSTED_CIDRS, a +// comma-separated list of CIDR ranges (e.g. "10.0.0.0/8,172.16.0.0/12") +// whose source IPs are allowed to access debug endpoints without Tailscale +// authentication. This will supersede both IsTailscaleIP() and +// TS_ALLOW_DEBUG_IP. +var debugTrustedCIDRs = envknob.RegisterString("TS_DEBUG_TRUSTED_CIDRS") + +// trustedCIDRs returns the parsed CIDR prefixes from TS_DEBUG_TRUSTED_CIDRS. +var trustedCIDRs = sync.OnceValue(func() []netip.Prefix { + return parseTrustedCIDRs(debugTrustedCIDRs()) +}) + +// parseTrustedCIDRs parses a comma-separated list of CIDR prefixes. +// It fatals on invalid entries, consistent with other envknob parsing. +func parseTrustedCIDRs(raw string) []netip.Prefix { + if raw == "" { + return nil + } + var prefixes []netip.Prefix + for _, s := range strings.Split(raw, ",") { + s = strings.TrimSpace(s) + if s == "" { + continue + } + pfx, err := netip.ParsePrefix(s) + if err != nil { + log.Fatalf("invalid CIDR in TS_DEBUG_TRUSTED_CIDRS: %q: %v", s, err) + } + prefixes = append(prefixes, pfx) + } + return prefixes +} + +// cidrsContain checks if the source IP is associated with one of the +// provided cidrs. +func cidrsContain(cidrs []netip.Prefix, ip netip.Addr) bool { + for _, pfx := range cidrs { + if pfx.Contains(ip) { + return true + } + } + return false +} + // AllowDebugAccess reports whether r should be permitted to access // various debug endpoints. func AllowDebugAccess(r *http.Request) bool { @@ -74,6 +120,9 @@ func AllowDebugAccess(r *http.Request) bool { if tsaddr.IsTailscaleIP(ip) || ip.IsLoopback() || ipStr == envknob.String("TS_ALLOW_DEBUG_IP") { return true } + if cidrsContain(trustedCIDRs(), ip) { + return true + } return false } @@ -734,8 +783,8 @@ func (h errorHandler) handleError(w http.ResponseWriter, r *http.Request, lw *lo // Extract a presentable, loggable error. var hOK bool - var hErr HTTPError - if errors.As(err, &hErr) { + hErr, hAsOK := errors.AsType[HTTPError](err) + if hAsOK { hOK = true if hErr.Code == 0 { lw.logf("[unexpected] HTTPError %v did not contain an HTTP status code, sending internal server error", hErr) @@ -854,9 +903,7 @@ func WriteHTTPError(w http.ResponseWriter, r *http.Request, e HTTPError) { h.Set("X-Content-Type-Options", "nosniff") // Custom headers from the error. - for k, vs := range e.Header { - h[k] = vs - } + maps.Copy(h, e.Header) // Write the msg back to the user. w.WriteHeader(e.Code) diff --git a/tsweb/tsweb_test.go b/tsweb/tsweb_test.go index af8e52420bd50..4e09bfffb2ae7 100644 --- a/tsweb/tsweb_test.go +++ b/tsweb/tsweb_test.go @@ -85,7 +85,7 @@ func TestStdHandler(t *testing.T) { wantBody string }{ { - name: "handler returns 200", + name: "handler-returns-200", rh: handlerCode(200), r: req(bgCtx, "http://example.com/"), wantCode: 200, @@ -102,7 +102,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns 200 with request ID", + name: "handler-returns-200-with-request-ID", rh: handlerCode(200), r: req(bgCtx, "http://example.com/"), wantCode: 200, @@ -119,7 +119,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns 404", + name: "handler-returns-404", rh: handlerCode(404), r: req(bgCtx, "http://example.com/foo"), wantCode: 404, @@ -135,7 +135,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns 404 with request ID", + name: "handler-returns-404-with-request-ID", rh: handlerCode(404), r: req(bgCtx, "http://example.com/foo"), wantCode: 404, @@ -151,7 +151,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns 404 via HTTPError", + name: "handler-returns-404-via-HTTPError", rh: handlerErr(0, Error(404, "not found", testErr)), r: req(bgCtx, "http://example.com/foo"), wantCode: 404, @@ -169,7 +169,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns 404 via HTTPError with request ID", + name: "handler-returns-404-via-HTTPError-with-request-ID", rh: handlerErr(0, Error(404, "not found", testErr)), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"), wantCode: 404, @@ -188,7 +188,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns 404 with nil child error", + name: "handler-returns-404-nil-child-error", rh: handlerErr(0, Error(404, "not found", nil)), r: req(bgCtx, "http://example.com/foo"), wantCode: 404, @@ -206,7 +206,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns 404 with request ID and nil child error", + name: "handler-returns-404-request-ID-nil-child-error", rh: handlerErr(0, Error(404, "not found", nil)), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"), wantCode: 404, @@ -225,7 +225,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns user-visible error", + name: "handler-returns-user-visible-error", rh: handlerErr(0, vizerror.New("visible error")), r: req(bgCtx, "http://example.com/foo"), wantCode: 500, @@ -243,7 +243,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns user-visible error with request ID", + name: "handler-returns-user-visible-error-with-request-ID", rh: handlerErr(0, vizerror.New("visible error")), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"), wantCode: 500, @@ -262,7 +262,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns user-visible error wrapped by private error", + name: "handler-returns-vizerror-wrapped-by-private-error", rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))), r: req(bgCtx, "http://example.com/foo"), wantCode: 500, @@ -280,7 +280,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns JSON-formatted HTTPError", + name: "handler-returns-JSON-formatted-HTTPError", rh: ReturnHandlerFunc(func(w http.ResponseWriter, r *http.Request) error { h := Error(http.StatusBadRequest, `{"isjson": true}`, errors.New("uh")) h.Header = http.Header{"Content-Type": {"application/json"}} @@ -303,7 +303,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns user-visible error wrapped by private error with request ID", + name: "handler-returns-vizerror-wrapped-by-private-error-with-request-ID", rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"), wantCode: 500, @@ -322,7 +322,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns generic error", + name: "handler-returns-generic-error", rh: handlerErr(0, testErr), r: req(bgCtx, "http://example.com/foo"), wantCode: 500, @@ -340,7 +340,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns generic error with request ID", + name: "handler-returns-generic-error-with-request-ID", rh: handlerErr(0, testErr), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"), wantCode: 500, @@ -359,7 +359,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns error after writing response", + name: "handler-returns-error-after-writing-response", rh: handlerErr(200, testErr), r: req(bgCtx, "http://example.com/foo"), wantCode: 200, @@ -376,7 +376,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns error after writing response with request ID", + name: "handler-returns-error-after-writing-response-with-request-ID", rh: handlerErr(200, testErr), r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"), wantCode: 200, @@ -394,7 +394,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler returns HTTPError after writing response", + name: "handler-returns-HTTPError-after-writing-response", rh: handlerErr(200, Error(404, "not found", testErr)), r: req(bgCtx, "http://example.com/foo"), wantCode: 200, @@ -411,7 +411,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler does nothing", + name: "handler-does-nothing", rh: handlerFunc(func(http.ResponseWriter, *http.Request) error { return nil }), r: req(bgCtx, "http://example.com/foo"), wantCode: 200, @@ -427,7 +427,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "handler hijacks conn", + name: "handler-hijacks-conn", rh: handlerFunc(func(w http.ResponseWriter, r *http.Request) error { _, _, err := w.(http.Hijacker).Hijack() if err != nil { @@ -450,7 +450,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "error handler gets run", + name: "error-handler-gets-run", rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler r: req(bgCtx, "http://example.com/"), wantCode: 200, @@ -472,7 +472,7 @@ func TestStdHandler(t *testing.T) { }, { - name: "error handler gets run with request ID", + name: "error-handler-gets-run-with-request-ID", rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/"), wantCode: 200, diff --git a/tsweb/varz/varz.go b/tsweb/varz/varz.go index a2286c7603be3..0df6e57751a7e 100644 --- a/tsweb/varz/varz.go +++ b/tsweb/varz/varz.go @@ -93,8 +93,8 @@ func prometheusMetric(prefix string, key string) (string, string, string) { typ = "histogram" key = strings.TrimPrefix(key, histogramPrefix) } - if strings.HasPrefix(key, labelMapPrefix) { - key = strings.TrimPrefix(key, labelMapPrefix) + if after, ok := strings.CutPrefix(key, labelMapPrefix); ok { + key = after if a, b, ok := strings.Cut(key, "_"); ok { label, key = a, b } @@ -154,7 +154,7 @@ func writePromExpVar(w io.Writer, prefix string, kv expvar.KeyValue) { case PrometheusMetricsReflectRooter: root := v.PrometheusMetricsReflectRoot() rv := reflect.ValueOf(root) - if rv.Type().Kind() == reflect.Ptr { + if rv.Type().Kind() == reflect.Pointer { if rv.IsNil() { return } @@ -419,8 +419,7 @@ func structTypeSortedFields(t reflect.Type) []sortedStructField { return v.([]sortedStructField) } fields := make([]sortedStructField, 0, t.NumField()) - for i, n := 0, t.NumField(); i < n; i++ { - sf := t.Field(i) + for sf := range t.Fields() { name := sf.Name if v := sf.Tag.Get("json"); v != "" { v, _, _ = strings.Cut(v, ",") @@ -433,7 +432,7 @@ func structTypeSortedFields(t reflect.Type) []sortedStructField { } } fields = append(fields, sortedStructField{ - Index: i, + Index: sf.Index[0], Name: name, SortName: removeTypePrefixes(name), MetricType: sf.Tag.Get("metrictype"), @@ -467,7 +466,7 @@ func foreachExportedStructField(rv reflect.Value, f func(fieldOrJSONName, metric sf := ssf.StructFieldType if ssf.MetricType != "" || sf.Type.Kind() == reflect.Struct { f(ssf.Name, ssf.MetricType, rv.Field(ssf.Index)) - } else if sf.Type.Kind() == reflect.Ptr && sf.Type.Elem().Kind() == reflect.Struct { + } else if sf.Type.Kind() == reflect.Pointer && sf.Type.Elem().Kind() == reflect.Struct { fv := rv.Field(ssf.Index) if !fv.IsNil() { f(ssf.Name, ssf.MetricType, fv.Elem()) diff --git a/tsweb/varz/varz_test.go b/tsweb/varz/varz_test.go index d041edb4b93d4..27094e77bf01d 100644 --- a/tsweb/varz/varz_test.go +++ b/tsweb/varz/varz_test.go @@ -205,7 +205,7 @@ func TestVarzHandler(t *testing.T) { "string_map", func() *expvar.Map { m := new(expvar.Map) - m.Set("a", expvar.NewString("foo")) + m.Set("a", new(expvar.String)) return m }(), "# skipping \"string_map\" expvar map key \"a\" with unknown value type *expvar.String\n", diff --git a/types/appctype/appconnector.go b/types/appctype/appconnector.go index 0af5db4c38672..b0fd5e65ab6c2 100644 --- a/types/appctype/appconnector.go +++ b/types/appctype/appconnector.go @@ -104,7 +104,9 @@ type Conn25Attr struct { // Connectors enumerates the app connectors which service these domains. // These can either be "*" to match any advertising connector, or a // tag of the form tag:. - Connectors []string `json:"connectors,omitempty"` - MagicIPPool []netipx.IPRange `json:"magicIPPool,omitempty"` - TransitIPPool []netipx.IPRange `json:"transitIPPool,omitempty"` + Connectors []string `json:"connectors,omitempty"` + V4MagicIPPool []netipx.IPRange `json:"v4MagicIPPool,omitempty"` + V4TransitIPPool []netipx.IPRange `json:"v4TransitIPPool,omitempty"` + V6MagicIPPool []netipx.IPRange `json:"v6MagicIPPool,omitempty"` + V6TransitIPPool []netipx.IPRange `json:"v6TransitIPPool,omitempty"` } diff --git a/types/dnstype/dnstype_test.go b/types/dnstype/dnstype_test.go index cf20f4f7f6618..8f746ab76bec7 100644 --- a/types/dnstype/dnstype_test.go +++ b/types/dnstype/dnstype_test.go @@ -33,13 +33,13 @@ func TestResolverEqual(t *testing.T) { want: true, }, { - name: "nil vs non-nil", + name: "nil-vs-non-nil", a: nil, b: &Resolver{}, want: false, }, { - name: "non-nil vs nil", + name: "non-nil-vs-nil", a: &Resolver{}, b: nil, want: false, @@ -51,13 +51,13 @@ func TestResolverEqual(t *testing.T) { want: true, }, { - name: "not equal addrs", + name: "not-equal-addrs", a: &Resolver{Addr: "dns.example.com"}, b: &Resolver{Addr: "dns2.example.com"}, want: false, }, { - name: "not equal bootstrap", + name: "not-equal-bootstrap", a: &Resolver{ Addr: "dns.example.com", BootstrapResolution: []netip.Addr{netip.MustParseAddr("8.8.8.8")}, @@ -69,13 +69,13 @@ func TestResolverEqual(t *testing.T) { want: false, }, { - name: "equal UseWithExitNode", + name: "equal-UseWithExitNode", a: &Resolver{Addr: "dns.example.com", UseWithExitNode: true}, b: &Resolver{Addr: "dns.example.com", UseWithExitNode: true}, want: true, }, { - name: "not equal UseWithExitNode", + name: "not-equal-UseWithExitNode", a: &Resolver{Addr: "dns.example.com", UseWithExitNode: true}, b: &Resolver{Addr: "dns.example.com", UseWithExitNode: false}, want: false, diff --git a/types/events/disco_update.go b/types/events/disco_update.go new file mode 100644 index 0000000000000..206c554a1d7f0 --- /dev/null +++ b/types/events/disco_update.go @@ -0,0 +1,30 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package events contains type used as eventbus topics in tailscaled. +package events + +import ( + "net/netip" + + "tailscale.com/types/key" +) + +// DiscoKeyAdvertisement is an event sent on the [eventbus.Bus] when a disco +// key has been received over TSMP. +// +// Its publisher is [tstun.Wrapper]; its main subscriber is +// [controlclient.Direct], that injects the received key into the netmap as if +// it was a netmap update from control. +type DiscoKeyAdvertisement struct { + Src netip.Addr // Src field is populated by the IP header of the packet, not from the payload itself. + Key key.DiscoPublic +} + +// PeerDiscoKeyUpdate is an event sent on the [eventbus.Bus] when +// [controlclient.Direct] deems that it cannot handle the key update. +// +// Its publisher is [controlclient.Direct]; its main subscriber is +// [wgengine.userspaceengine], that injects the received key into its +// [magicsock.Conn] in order to set up the key directly. +type PeerDiscoKeyUpdate DiscoKeyAdvertisement diff --git a/types/ipproto/ipproto_test.go b/types/ipproto/ipproto_test.go index 8bfeb13fa4246..6d8be47a9046c 100644 --- a/types/ipproto/ipproto_test.go +++ b/types/ipproto/ipproto_test.go @@ -69,7 +69,7 @@ func TestProtoUnmarshalText(t *testing.T) { for i := range 256 { var p Proto - must.Do(p.UnmarshalText([]byte(fmt.Sprintf("%d", i)))) + must.Do(p.UnmarshalText(fmt.Appendf(nil, "%d", i))) if got, want := p, Proto(i); got != want { t.Errorf("Proto(%d) = %v, want %v", i, got, want) } @@ -122,7 +122,7 @@ func TestProtoUnmarshalJSON(t *testing.T) { var p Proto for i := range 256 { - j := []byte(fmt.Sprintf(`%d`, i)) + j := fmt.Appendf(nil, `%d`, i) must.Do(json.Unmarshal(j, &p)) if got, want := p, Proto(i); got != want { t.Errorf("Proto(%d) = %v, want %v", i, got, want) @@ -130,7 +130,7 @@ func TestProtoUnmarshalJSON(t *testing.T) { } for name, wantProto := range acceptedNames { - must.Do(json.Unmarshal([]byte(fmt.Sprintf(`"%s"`, name)), &p)) + must.Do(json.Unmarshal(fmt.Appendf(nil, `"%s"`, name), &p)) if got, want := p, wantProto; got != want { t.Errorf("Proto(%q) = %v, want %v", name, got, want) } diff --git a/types/jsonx/json_test.go b/types/jsonx/json_test.go index 5c302d9746c3e..8b0abbab64686 100644 --- a/types/jsonx/json_test.go +++ b/types/jsonx/json_test.go @@ -10,7 +10,6 @@ import ( "github.com/go-json-experiment/json" "github.com/go-json-experiment/json/jsontext" "github.com/google/go-cmp/cmp" - "tailscale.com/types/ptr" ) type Interface interface { @@ -72,7 +71,7 @@ func TestInterfaceCoders(t *testing.T) { wantJSON: `{"Foo":"hello"}`, }, { label: "BarPointer", - wantVal: InterfaceWrapper{ptr.To(Bar(5))}, + wantVal: InterfaceWrapper{new(Bar(5))}, wantJSON: `{"Bar":5}`, }, { label: "BarValue", diff --git a/types/key/disco.go b/types/key/disco.go index f46347c919ebb..7fa476dc35ec0 100644 --- a/types/key/disco.go +++ b/types/key/disco.go @@ -42,6 +42,16 @@ func NewDisco() DiscoPrivate { return ret } +// DiscoPrivateFromRaw32 parses a 32-byte raw value as a DiscoPrivate. +func DiscoPrivateFromRaw32(raw mem.RO) DiscoPrivate { + if raw.Len() != 32 { + panic("input has wrong size") + } + var ret DiscoPrivate + raw.Copy(ret.k[:]) + return ret +} + // IsZero reports whether k is the zero value. func (k DiscoPrivate) IsZero() bool { return k.Equal(DiscoPrivate{}) diff --git a/types/key/nl.go b/types/key/nl.go index fc11d5b20ff64..0e8c5ed966260 100644 --- a/types/key/nl.go +++ b/types/key/nl.go @@ -119,7 +119,7 @@ type NLPublic struct { // NLPublicFromEd25519Unsafe converts an ed25519 public key into // a type of NLPublic. // -// New uses of this function should be avoided, as its possible to +// New uses of this function should be avoided, as it's possible to // accidentally construct an NLPublic from a non network-lock key. func NLPublicFromEd25519Unsafe(public ed25519.PublicKey) NLPublic { var out NLPublic diff --git a/types/key/node.go b/types/key/node.go index 1402aad361870..a1d8e47bafce1 100644 --- a/types/key/node.go +++ b/types/key/node.go @@ -15,6 +15,7 @@ import ( "golang.org/x/crypto/curve25519" "golang.org/x/crypto/nacl/box" "tailscale.com/types/structs" + "tailscale.com/util/bufiox" ) const ( @@ -61,6 +62,14 @@ func NewNode() NodePrivate { return ret } +// Raw32 returns k as 32 raw bytes. +func (k NodePrivate) Raw32() [32]byte { return k.k } + +// NodePrivateAs returns a NodePrivate as a named fixed-size array of bytes. +// It's intended for interoperability with wireguard-go's +// device.NoisePrivateKey type. +func NodePrivateAs[T ~[32]byte](k NodePrivate) T { return k.k } + // NodePrivateFromRaw32 parses a 32-byte raw value as a NodePrivate. // // Deprecated: only needed to cast from legacy node private key types, @@ -239,42 +248,34 @@ func (k NodePublic) AppendTo(buf []byte) []byte { } // ReadRawWithoutAllocating initializes k with bytes read from br. -// The reading is done ~4x slower than io.ReadFull, but in exchange is -// allocation-free. +// It uses [bufiox.ReadFull] to read without heap allocations. func (k *NodePublic) ReadRawWithoutAllocating(br *bufio.Reader) error { var z NodePublic if *k != z { return errors.New("refusing to read into non-zero NodePublic") } - // This is ~4x slower than io.ReadFull, but using io.ReadFull - // causes one extra alloc, which is significant for the DERP - // server that consumes this method. So, process stuff slower but - // without allocation. - // - // Dear future: if io.ReadFull stops causing stuff to escape, you - // should switch back to that. - for i := range k.k { - b, err := br.ReadByte() - if err != nil { - return err - } - k.k[i] = b - } - return nil + _, err := bufiox.ReadFull(br, k.k[:]) + return err } -// WriteRawWithoutAllocating writes out k as 32 bytes to bw. -// The writing is done ~3x slower than bw.Write, but in exchange is -// allocation-free. +// WriteRawWithoutAllocating writes out k as 32 big-endian bytes to bw. +// +// It uses AvailableBuffer to append directly into bufio's internal +// buffer without allocation, falling back to WriteByte when the +// buffer has insufficient space. func (k NodePublic) WriteRawWithoutAllocating(bw *bufio.Writer) error { - // Equivalent to bw.Write(k.k[:]), but without causing an - // escape-related alloc. - // - // Dear future: if bw.Write(k.k[:]) stops causing stuff to escape, - // you should switch back to that. + // Fast path: enough space in the buffer to append directly. + if bw.Available() >= len(k.k) { + buf := bw.AvailableBuffer() + buf = append(buf, k.k[:]...) + _, err := bw.Write(buf) + return err + } + // Slow path: buffer nearly full. Write byte-at-a-time to let + // bufio flush as needed, avoiding a heap allocation from append + // growing past AvailableBuffer's capacity. for _, b := range k.k { - err := bw.WriteByte(b) - if err != nil { + if err := bw.WriteByte(b); err != nil { return err } } diff --git a/types/key/node_test.go b/types/key/node_test.go index 77eef2b28d2f5..020ddd1f1ff6d 100644 --- a/types/key/node_test.go +++ b/types/key/node_test.go @@ -7,6 +7,7 @@ import ( "bufio" "bytes" "encoding/json" + "io" "strings" "testing" ) @@ -125,20 +126,91 @@ func TestNodeReadRawWithoutAllocating(t *testing.T) { } } -func TestNodeWriteRawWithoutAllocating(t *testing.T) { - buf := make([]byte, 0, 32) - w := bytes.NewBuffer(buf) - bw := bufio.NewWriter(w) - got := testing.AllocsPerRun(1000, func() { - w.Reset() - bw.Reset(w) +func BenchmarkNodeReadRawWithoutAllocating(b *testing.B) { + buf := make([]byte, 32) + for i := range buf { + buf[i] = 0x42 + } + r := bytes.NewReader(buf) + br := bufio.NewReader(r) + b.ReportAllocs() + for b.Loop() { + r.Reset(buf) + br.Reset(r) var k NodePublic + if err := k.ReadRawWithoutAllocating(br); err != nil { + b.Fatal(err) + } + } +} + +func TestNodeWriteRawWithoutAllocating(t *testing.T) { + var k NodePublic + for i := range k.k { + k.k[i] = byte(i) + } + + // Test fast path (empty buffer, plenty of space). + t.Run("fast", func(t *testing.T) { + var buf bytes.Buffer + bw := bufio.NewWriter(&buf) if err := k.WriteRawWithoutAllocating(bw); err != nil { t.Fatalf("WriteRawWithoutAllocating: %v", err) } + bw.Flush() + if got := buf.Bytes(); !bytes.Equal(got, k.k[:]) { + t.Errorf("wrote % 02x, want % 02x", got, k.k) + } }) - if want := 0.0; got != want { - t.Fatalf("WriteRawWithoutAllocating got %f allocs, want %f", got, want) + + // Test slow path (buffer nearly full, less than 32 bytes available). + t.Run("slow", func(t *testing.T) { + var buf bytes.Buffer + const smallBuf = 40 + bw := bufio.NewWriterSize(&buf, smallBuf) + // Fill buffer to leave less than 32 bytes available. + padding := make([]byte, smallBuf-len(k.k)+1) + if _, err := bw.Write(padding); err != nil { + t.Fatalf("Write padding: %v", err) + } + if err := k.WriteRawWithoutAllocating(bw); err != nil { + t.Fatalf("WriteRawWithoutAllocating: %v", err) + } + bw.Flush() + got := buf.Bytes()[len(padding):] + if !bytes.Equal(got, k.k[:]) { + t.Errorf("wrote % 02x, want % 02x", got, k.k) + } + }) + + // Verify zero allocations on fast path. + t.Run("allocs", func(t *testing.T) { + w := bytes.NewBuffer(make([]byte, 0, 32)) + bw := bufio.NewWriter(w) + got := testing.AllocsPerRun(1000, func() { + w.Reset() + bw.Reset(w) + if err := k.WriteRawWithoutAllocating(bw); err != nil { + t.Fatalf("WriteRawWithoutAllocating: %v", err) + } + }) + if got != 0 { + t.Fatalf("WriteRawWithoutAllocating allocs = %f, want 0", got) + } + }) +} + +func BenchmarkNodeWriteRawWithoutAllocating(b *testing.B) { + bw := bufio.NewWriter(io.Discard) + var k NodePublic + for i := range k.k { + k.k[i] = 0x42 + } + b.ReportAllocs() + for b.Loop() { + if err := k.WriteRawWithoutAllocating(bw); err != nil { + b.Fatal(err) + } } } diff --git a/types/lazy/deferred.go b/types/lazy/deferred.go index 582090ab93112..6e96f61e7af04 100644 --- a/types/lazy/deferred.go +++ b/types/lazy/deferred.go @@ -6,8 +6,6 @@ package lazy import ( "sync" "sync/atomic" - - "tailscale.com/types/ptr" ) // DeferredInit allows one or more funcs to be deferred @@ -91,7 +89,7 @@ func (d *DeferredInit) doSlow() (err *error) { }() for _, f := range d.funcs { if err := f(); err != nil { - return ptr.To(err) + return new(err) } } return nilErrPtr diff --git a/types/lazy/deferred_test.go b/types/lazy/deferred_test.go index 61cc8f8ac6c27..4b2bb07ee2fbd 100644 --- a/types/lazy/deferred_test.go +++ b/types/lazy/deferred_test.go @@ -145,13 +145,11 @@ func TestDeferredInit(t *testing.T) { // Call [DeferredInit.Do] concurrently. const N = 10000 for range N { - wg.Add(1) - go func() { + wg.Go(func() { gotErr := di.Do() checkError(t, gotErr, nil, false) checkCalls() - wg.Done() - }() + }) } wg.Wait() }) @@ -193,12 +191,10 @@ func TestDeferredErr(t *testing.T) { var wg sync.WaitGroup N := 10000 for range N { - wg.Add(1) - go func() { + wg.Go(func() { gotErr := di.Do() checkError(t, gotErr, tt.wantErr, false) - wg.Done() - }() + }) } wg.Wait() }) @@ -254,11 +250,9 @@ func TestDeferAfterDo(t *testing.T) { const N = 10000 var wg sync.WaitGroup for range N { - wg.Add(1) - go func() { + wg.Go(func() { deferOnce() - wg.Done() - }() + }) } if err := di.Do(); err != nil { diff --git a/types/lazy/lazy.go b/types/lazy/lazy.go index 915ae2002c135..a24139fe1ab07 100644 --- a/types/lazy/lazy.go +++ b/types/lazy/lazy.go @@ -7,13 +7,11 @@ package lazy import ( "sync" "sync/atomic" - - "tailscale.com/types/ptr" ) // nilErrPtr is a sentinel *error value for SyncValue.err to signal // that SyncValue.v is valid. -var nilErrPtr = ptr.To[error](nil) +var nilErrPtr = new(error(nil)) // SyncValue is a lazily computed value. // @@ -80,7 +78,7 @@ func (z *SyncValue[T]) GetErr(fill func() (T, error)) (T, error) { // Update z.err after z.v; see field docs. if err != nil { - z.err.Store(ptr.To(err)) + z.err.Store(new(err)) } else { z.err.Store(nilErrPtr) } @@ -145,7 +143,7 @@ func (z *SyncValue[T]) SetForTest(tb testing_TB, val T, err error) { z.v = val if err != nil { - z.err.Store(ptr.To(err)) + z.err.Store(new(err)) } else { z.err.Store(nilErrPtr) } diff --git a/types/netmap/netmap.go b/types/netmap/netmap.go index ac95254daee1d..fbf415be0a95b 100644 --- a/types/netmap/netmap.go +++ b/types/netmap/netmap.go @@ -146,6 +146,34 @@ func (nm *NetworkMap) GetIPVIPServiceMap() IPServiceMappings { return res } +// Services returns the Services visible (accessible) to this node, +// decoded from [tailcfg.NodeAttrPrefixServices] entries in the self node's +// CapMap. The returned map is keyed by [tailcfg.ServiceDetails.Name], +// which is the canonical service name. It returns nil if nm is nil +// or SelfNode is invalid. +// +// TODO(adrianosela): cache the result of decoding the capmap so +// we don't have to decode it multiple times after each netmap update. +func (nm *NetworkMap) Services() map[tailcfg.ServiceName]tailcfg.ServiceDetails { + if nm == nil || !nm.SelfNode.Valid() { + return nil + } + result := make(map[tailcfg.ServiceName]tailcfg.ServiceDetails) + for cap := range nm.SelfNode.CapMap().All() { + if !strings.HasPrefix(string(cap), string(tailcfg.NodeAttrPrefixServices)) { + continue + } + svcs, err := tailcfg.UnmarshalNodeCapViewJSON[tailcfg.ServiceDetails](nm.SelfNode.CapMap(), cap) + if err != nil || len(svcs) < 1 { + continue + } + // NOTE(adrianosela): the NodeCapMap key suffix is opaque and MUST not + // be parsed or relied upon (so we extract name from the inner field). + result[svcs[0].Name] = svcs[0] + } + return result +} + // SelfNodeOrZero returns the self node, or a zero value if nm is nil. func (nm *NetworkMap) SelfNodeOrZero() tailcfg.NodeView { if nm == nil { @@ -284,13 +312,6 @@ func (nm *NetworkMap) TailnetDisplayName() string { return tailnetDisplayNames[0] } -// HasSelfCapability reports whether nm.SelfNode contains capability c. -// -// It exists to satisify an unused (as of 2025-01-04) interface in the logknob package. -func (nm *NetworkMap) HasSelfCapability(c tailcfg.NodeCapability) bool { - return nm.AllCaps.Contains(c) -} - func (nm *NetworkMap) String() string { return nm.Concise() } diff --git a/types/netmap/nodemut.go b/types/netmap/nodemut.go index 5c9000d56ef38..901296b1fc337 100644 --- a/types/netmap/nodemut.go +++ b/types/netmap/nodemut.go @@ -12,7 +12,6 @@ import ( "time" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) // NodeMutation is the common interface for types that describe @@ -55,7 +54,7 @@ type NodeMutationOnline struct { } func (m NodeMutationOnline) Apply(n *tailcfg.Node) { - n.Online = ptr.To(m.Online) + n.Online = new(m.Online) } // NodeMutationLastSeen is a NodeMutation that says a node's LastSeen @@ -66,14 +65,14 @@ type NodeMutationLastSeen struct { } func (m NodeMutationLastSeen) Apply(n *tailcfg.Node) { - n.LastSeen = ptr.To(m.LastSeen) + n.LastSeen = new(m.LastSeen) } var peerChangeFields = sync.OnceValue(func() []reflect.StructField { var fields []reflect.StructField rt := reflect.TypeFor[tailcfg.PeerChange]() - for i := range rt.NumField() { - fields = append(fields, rt.Field(i)) + for field := range rt.Fields() { + fields = append(fields, field) } return fields }) diff --git a/types/netmap/nodemut_test.go b/types/netmap/nodemut_test.go index f7302d48df097..1ae2ab1f98bdd 100644 --- a/types/netmap/nodemut_test.go +++ b/types/netmap/nodemut_test.go @@ -14,7 +14,6 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/types/opt" - "tailscale.com/types/ptr" ) // tests mapResponseContainsNonPatchFields @@ -35,7 +34,7 @@ func TestMapResponseContainsNonPatchFields(t *testing.T) { return reflect.ValueOf(int64(1)).Convert(t) case reflect.Slice: return reflect.MakeSlice(t, 1, 1) - case reflect.Ptr: + case reflect.Pointer: return reflect.New(t.Elem()) case reflect.Map: return reflect.MakeMap(t) @@ -44,8 +43,7 @@ func TestMapResponseContainsNonPatchFields(t *testing.T) { } rt := reflect.TypeFor[tailcfg.MapResponse]() - for i := range rt.NumField() { - f := rt.Field(i) + for f := range rt.Fields() { var want bool switch f.Name { @@ -117,7 +115,7 @@ func TestMutationsFromMapResponse(t *testing.T) { name: "patch-online", mr: fromChanges(&tailcfg.PeerChange{ NodeID: 1, - Online: ptr.To(true), + Online: new(true), }), want: muts(NodeMutationOnline{1, true}), }, @@ -125,7 +123,7 @@ func TestMutationsFromMapResponse(t *testing.T) { name: "patch-online-false", mr: fromChanges(&tailcfg.PeerChange{ NodeID: 1, - Online: ptr.To(false), + Online: new(false), }), want: muts(NodeMutationOnline{1, false}), }, @@ -133,7 +131,7 @@ func TestMutationsFromMapResponse(t *testing.T) { name: "patch-lastseen", mr: fromChanges(&tailcfg.PeerChange{ NodeID: 1, - LastSeen: ptr.To(time.Unix(12345, 0)), + LastSeen: new(time.Unix(12345, 0)), }), want: muts(NodeMutationLastSeen{1, time.Unix(12345, 0)}), }, diff --git a/types/persist/persist_clone.go b/types/persist/persist_clone.go index f5fa36b6da0fc..b43dcc7fd979e 100644 --- a/types/persist/persist_clone.go +++ b/types/persist/persist_clone.go @@ -19,6 +19,7 @@ func (src *Persist) Clone() *Persist { } dst := new(Persist) *dst = *src + dst.UserProfile = *src.UserProfile.Clone() if src.AttestationKey != nil { dst.AttestationKey = src.AttestationKey.Clone() } diff --git a/types/persist/persist_test.go b/types/persist/persist_test.go index b25af5a0b2066..33773013d667f 100644 --- a/types/persist/persist_test.go +++ b/types/persist/persist_test.go @@ -12,8 +12,8 @@ import ( ) func fieldsOf(t reflect.Type) (fields []string) { - for i := range t.NumField() { - if name := t.Field(i).Name; name != "_" { + for field := range t.Fields() { + if name := field.Name; name != "_" { fields = append(fields, name) } } diff --git a/types/persist/persist_view.go b/types/persist/persist_view.go index b18634917c651..f33d222c6fb8d 100644 --- a/types/persist/persist_view.go +++ b/types/persist/persist_view.go @@ -90,7 +90,7 @@ func (v PersistView) PrivateNodeKey() key.NodePrivate { return v.Đļ.PrivateNodeK // needed to request key rotation func (v PersistView) OldPrivateNodeKey() key.NodePrivate { return v.Đļ.OldPrivateNodeKey } -func (v PersistView) UserProfile() tailcfg.UserProfile { return v.Đļ.UserProfile } +func (v PersistView) UserProfile() tailcfg.UserProfileView { return v.Đļ.UserProfile.View() } func (v PersistView) NetworkLockKey() key.NLPrivate { return v.Đļ.NetworkLockKey } func (v PersistView) NodeID() tailcfg.StableNodeID { return v.Đļ.NodeID } func (v PersistView) AttestationKey() tailcfg.StableNodeID { panic("unsupported") } diff --git a/types/prefs/item.go b/types/prefs/item.go index fdb9301f9fdf8..564e8ffde7d0f 100644 --- a/types/prefs/item.go +++ b/types/prefs/item.go @@ -9,7 +9,6 @@ import ( jsonv2 "github.com/go-json-experiment/json" "github.com/go-json-experiment/json/jsontext" "tailscale.com/types/opt" - "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/must" ) @@ -47,7 +46,7 @@ func (i *Item[T]) SetManagedValue(val T) { // It is a runtime error to call [Item.Clone] if T contains pointers // but does not implement [views.Cloner]. func (i Item[T]) Clone() *Item[T] { - res := ptr.To(i) + res := new(i) if v, ok := i.ValueOk(); ok { res.s.Value.Set(must.Get(deepClone(v))) } diff --git a/types/prefs/list.go b/types/prefs/list.go index 20e4dad463135..c6881991ad769 100644 --- a/types/prefs/list.go +++ b/types/prefs/list.go @@ -12,7 +12,6 @@ import ( "github.com/go-json-experiment/json/jsontext" "golang.org/x/exp/constraints" "tailscale.com/types/opt" - "tailscale.com/types/ptr" "tailscale.com/types/views" ) @@ -62,7 +61,7 @@ func (ls *List[T]) View() ListView[T] { // Clone returns a copy of l that aliases no memory with l. func (ls List[T]) Clone() *List[T] { - res := ptr.To(ls) + res := new(ls) if v, ok := ls.s.Value.GetOk(); ok { res.s.Value.Set(append(v[:0:0], v...)) } diff --git a/types/prefs/map.go b/types/prefs/map.go index 6bf1948b87ab4..07cb84f0da56a 100644 --- a/types/prefs/map.go +++ b/types/prefs/map.go @@ -11,7 +11,6 @@ import ( "github.com/go-json-experiment/json/jsontext" "golang.org/x/exp/constraints" "tailscale.com/types/opt" - "tailscale.com/types/ptr" "tailscale.com/types/views" ) @@ -44,7 +43,7 @@ func (m *Map[K, V]) View() MapView[K, V] { // Clone returns a copy of m that aliases no memory with m. func (m Map[K, V]) Clone() *Map[K, V] { - res := ptr.To(m) + res := new(m) if v, ok := m.s.Value.GetOk(); ok { res.s.Value.Set(maps.Clone(v)) } diff --git a/types/prefs/prefs_clone_test.go b/types/prefs/prefs_clone_test.go index 07dc24fdc7361..1914a0c2551f6 100644 --- a/types/prefs/prefs_clone_test.go +++ b/types/prefs/prefs_clone_test.go @@ -7,8 +7,6 @@ package prefs import ( "net/netip" - - "tailscale.com/types/ptr" ) // Clone makes a deep copy of TestPrefs. @@ -67,7 +65,7 @@ func (src *TestBundle) Clone() *TestBundle { dst := new(TestBundle) *dst = *src if dst.Nested != nil { - dst.Nested = ptr.To(*src.Nested) + dst.Nested = new(*src.Nested) } return dst } diff --git a/types/prefs/struct_list.go b/types/prefs/struct_list.go index 09aa808ccc37e..e1c1863fc5dc1 100644 --- a/types/prefs/struct_list.go +++ b/types/prefs/struct_list.go @@ -11,7 +11,6 @@ import ( jsonv2 "github.com/go-json-experiment/json" "github.com/go-json-experiment/json/jsontext" "tailscale.com/types/opt" - "tailscale.com/types/ptr" "tailscale.com/types/views" ) @@ -45,7 +44,7 @@ func (ls *StructList[T]) SetManagedValue(val []T) { // Clone returns a copy of l that aliases no memory with l. func (ls StructList[T]) Clone() *StructList[T] { - res := ptr.To(ls) + res := new(ls) if v, ok := ls.s.Value.GetOk(); ok { res.s.Value.Set(deepCloneSlice(v)) } diff --git a/types/prefs/struct_map.go b/types/prefs/struct_map.go index 2f2715a62a94a..374d8a92ee925 100644 --- a/types/prefs/struct_map.go +++ b/types/prefs/struct_map.go @@ -9,7 +9,6 @@ import ( jsonv2 "github.com/go-json-experiment/json" "github.com/go-json-experiment/json/jsontext" "tailscale.com/types/opt" - "tailscale.com/types/ptr" "tailscale.com/types/views" ) @@ -43,7 +42,7 @@ func (m *StructMap[K, V]) SetManagedValue(val map[K]V) { // Clone returns a copy of m that aliases no memory with m. func (m StructMap[K, V]) Clone() *StructMap[K, V] { - res := ptr.To(m) + res := new(m) if v, ok := m.s.Value.GetOk(); ok { res.s.Value.Set(deepCloneMap(v)) } diff --git a/types/ptr/ptr.go b/types/ptr/ptr.go index 5b65a0e1c13e7..ba2b9e5857e8f 100644 --- a/types/ptr/ptr.go +++ b/types/ptr/ptr.go @@ -2,9 +2,18 @@ // SPDX-License-Identifier: BSD-3-Clause // Package ptr contains the ptr.To function. +// +// Deprecated: Use Go 1.26's new(value) expression instead. +// See https://go.dev/doc/go1.26#language. package ptr // To returns a pointer to a shallow copy of v. +// +// Deprecated: Use Go 1.26's new(value) expression instead. +// For example, ptr.To(42) can be written as new(42). +// See https://go.dev/doc/go1.26#language. +// +//go:fix inline func To[T any](v T) *T { - return &v + return new(v) } diff --git a/types/views/views.go b/types/views/views.go index 9260311edc29a..fe70e227fc64c 100644 --- a/types/views/views.go +++ b/types/views/views.go @@ -19,7 +19,6 @@ import ( jsonv2 "github.com/go-json-experiment/json" "github.com/go-json-experiment/json/jsontext" "go4.org/mem" - "tailscale.com/types/ptr" ) // ByteSlice is a read-only accessor for types that are backed by a []byte. @@ -901,7 +900,7 @@ func (p ValuePointer[T]) Clone() *T { if p.Đļ == nil { return nil } - return ptr.To(*p.Đļ) + return new(*p.Đļ) } // String implements [fmt.Stringer]. @@ -969,8 +968,8 @@ func containsPointers(typ reflect.Type) bool { if isWellKnownImmutableStruct(typ) { return false } - for i := range typ.NumField() { - if containsPointers(typ.Field(i).Type) { + for field := range typ.Fields() { + if containsPointers(field.Type) { return true } } diff --git a/update-flake.sh b/update-flake.sh deleted file mode 100755 index c22572b860248..0000000000000 --- a/update-flake.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/sh -# Updates SRI hashes for flake.nix. - -set -eu - -OUT=$(mktemp -d -t nar-hash-XXXXXX) -rm -rf "$OUT" - -./tool/go mod vendor -o "$OUT" -./tool/go run tailscale.com/cmd/nardump --sri "$OUT" >go.mod.sri -rm -rf "$OUT" - -GOOUT=$(mktemp -d -t gocross-XXXXXX) -GOREV=$(xargs < ./go.toolchain.rev) -TARBALL="$GOOUT/go-$GOREV.tar.gz" -curl -Ls -o "$TARBALL" "https://github.com/tailscale/go/archive/$GOREV.tar.gz" -tar -xzf "$TARBALL" -C "$GOOUT" -./tool/go run tailscale.com/cmd/nardump --sri "$GOOUT/go-$GOREV" > go.toolchain.rev.sri -rm -rf "$GOOUT" - -# nix-direnv only watches the top-level nix file for changes. As a -# result, when we change a referenced SRI file, we have to cause some -# change to shell.nix and flake.nix as well, so that nix-direnv -# notices and reevaluates everything. Sigh. -perl -pi -e "s,# nix-direnv cache busting line:.*,# nix-direnv cache busting line: $(cat go.mod.sri)," shell.nix -perl -pi -e "s,# nix-direnv cache busting line:.*,# nix-direnv cache busting line: $(cat go.mod.sri)," flake.nix diff --git a/util/bufiox/bufiox.go b/util/bufiox/bufiox.go new file mode 100644 index 0000000000000..ceea9aefe226f --- /dev/null +++ b/util/bufiox/bufiox.go @@ -0,0 +1,31 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package bufiox provides extensions to the standard bufio package. +package bufiox + +import "io" + +// BufferedReader is an interface for readers that support peeking +// into an internal buffer, like [bufio.Reader]. +type BufferedReader interface { + Peek(n int) ([]byte, error) + Discard(n int) (discarded int, err error) +} + +// ReadFull reads exactly len(buf) bytes from r into buf, like +// [io.ReadFull], but without heap allocations. It uses Peek to +// access the buffered data directly, copies it into buf, then +// discards the consumed bytes. If an error occurs, +// discard is not called and the buffer is left unchanged. +func ReadFull(r BufferedReader, buf []byte) (int, error) { + b, err := r.Peek(len(buf)) + if err != nil { + if len(b) > 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } + return 0, err + } + defer r.Discard(len(buf)) + return copy(buf, b), nil +} diff --git a/util/bufiox/bufiox_test.go b/util/bufiox/bufiox_test.go new file mode 100644 index 0000000000000..727bb36997a6d --- /dev/null +++ b/util/bufiox/bufiox_test.go @@ -0,0 +1,98 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package bufiox + +import ( + "bufio" + "bytes" + "io" + "testing" +) + +func TestReadFull(t *testing.T) { + data := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} + br := bufio.NewReader(bytes.NewReader(data)) + + var buf [5]byte + n, err := ReadFull(br, buf[:]) + if err != nil { + t.Fatalf("ReadFull: %v", err) + } + if n != len(buf) { + t.Fatalf("n = %d, want %d", n, len(buf)) + } + if want := [5]byte{0x01, 0x02, 0x03, 0x04, 0x05}; buf != want { + t.Fatalf("buf = %v, want %v", buf, want) + } + + // Remaining bytes should still be readable. + var rest [3]byte + n, err = ReadFull(br, rest[:]) + if err != nil { + t.Fatalf("ReadFull rest: %v", err) + } + if n != len(rest) { + t.Fatalf("rest n = %d, want %d", n, len(rest)) + } + if want := [3]byte{0x06, 0x07, 0x08}; rest != want { + t.Fatalf("rest = %v, want %v", rest, want) + } +} + +func TestReadFullShort(t *testing.T) { + data := []byte{0x01, 0x02} + br := bufio.NewReader(bytes.NewReader(data)) + + var buf [5]byte + _, err := ReadFull(br, buf[:]) + if err != io.ErrUnexpectedEOF { + t.Fatalf("err = %v, want %v", err, io.ErrUnexpectedEOF) + } +} + +func TestReadFullEmpty(t *testing.T) { + br := bufio.NewReader(bytes.NewReader(nil)) + + var buf [1]byte + _, err := ReadFull(br, buf[:]) + if err != io.EOF { + t.Fatalf("err = %v, want %v", err, io.EOF) + } +} + +func TestReadFullZeroAllocs(t *testing.T) { + data := make([]byte, 64) + rd := bytes.NewReader(data) + br := bufio.NewReader(rd) + + var buf [32]byte + got := testing.AllocsPerRun(1000, func() { + rd.Reset(data) + br.Reset(rd) + _, err := ReadFull(br, buf[:]) + if err != nil { + t.Fatalf("ReadFull: %v", err) + } + }) + if got != 0 { + t.Fatalf("ReadFull allocs = %f, want 0", got) + } +} + +type nopReader struct{} + +func (nopReader) Read(p []byte) (int, error) { return len(p), nil } + +func BenchmarkReadFull(b *testing.B) { + br := bufio.NewReader(nopReader{}) + var buf [32]byte + b.ReportAllocs() + b.ResetTimer() + for range b.N { + _, err := ReadFull(br, buf[:]) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/util/clientmetric/clientmetric.go b/util/clientmetric/clientmetric.go index b67cbbd39aa1e..98068b9faf693 100644 --- a/util/clientmetric/clientmetric.go +++ b/util/clientmetric/clientmetric.go @@ -22,6 +22,7 @@ import ( "tailscale.com/feature/buildfeatures" "tailscale.com/util/set" + "tailscale.com/util/testenv" ) var ( @@ -452,6 +453,24 @@ func (b *deltaEncBuf) writeHexVarint(v int64) { b.buf.Write(hexBuf) } +// ResetForTest resets all client metric values to zero. +// It panics if not in a test or if called from a parallel test. +func ResetForTest(t testenv.TB) { + if !testenv.InTest() { + panic("clientmetric.ResetForTest called outside a test") + } + if testenv.InParallelTest(t) { + panic("clientmetric.ResetForTest called from a parallel test") + } + mu.Lock() + defer mu.Unlock() + for _, m := range metrics { + if m.v != nil { + atomic.StoreInt64(m.v, 0) + } + } +} + var TestHooks testHooks type testHooks struct{} diff --git a/util/clientmetric/omit.go b/util/clientmetric/omit.go index 725b18fe48d3c..74018f12ac154 100644 --- a/util/clientmetric/omit.go +++ b/util/clientmetric/omit.go @@ -26,6 +26,9 @@ func WritePrometheusExpositionFormat(any) {} var zeroMetric Metric -func NewCounter(string) *Metric { return &zeroMetric } -func NewGauge(string) *Metric { return &zeroMetric } -func NewAggregateCounter(string) *Metric { return &zeroMetric } +func NewCounter(string) *Metric { return &zeroMetric } +func NewGauge(string) *Metric { return &zeroMetric } +func NewAggregateCounter(string) *Metric { return &zeroMetric } +func NewCounterFunc(string, func() int64) *Metric { return &zeroMetric } + +func ResetForTest(any) {} diff --git a/util/cmpver/version_test.go b/util/cmpver/version_test.go index 5688aa037927b..b3ab1b0289211 100644 --- a/util/cmpver/version_test.go +++ b/util/cmpver/version_test.go @@ -16,77 +16,77 @@ func TestCompare(t *testing.T) { want int }{ { - name: "both empty", + name: "both-empty", want: 0, }, { - name: "v1 empty", + name: "v1-empty", v2: "1.2.3", want: -1, }, { - name: "v2 empty", + name: "v2-empty", v1: "1.2.3", want: 1, }, { - name: "semver major", + name: "semver-major", v1: "2.0.0", v2: "1.9.9", want: 1, }, { - name: "semver major", + name: "semver-major", v1: "2.0.0", v2: "1.9.9", want: 1, }, { - name: "semver minor", + name: "semver-minor", v1: "1.9.0", v2: "1.8.9", want: 1, }, { - name: "semver patch", + name: "semver-patch", v1: "1.9.9", v2: "1.9.8", want: 1, }, { - name: "semver equal", + name: "semver-equal", v1: "1.9.8", v2: "1.9.8", want: 0, }, { - name: "tailscale major", + name: "tailscale-major", v1: "1.0-0", v2: "0.97-105", want: 1, }, { - name: "tailscale minor", + name: "tailscale-minor", v1: "0.98-0", v2: "0.97-105", want: 1, }, { - name: "tailscale patch", + name: "tailscale-patch", v1: "0.97-120", v2: "0.97-105", want: 1, }, { - name: "tailscale equal", + name: "tailscale-equal", v1: "0.97-105", v2: "0.97-105", want: 0, }, { - name: "tailscale weird extra field", + name: "tailscale-weird-extra-field", v1: "0.96.1-0", // more fields == larger v2: "0.96-105", want: 1, @@ -96,7 +96,7 @@ func TestCompare(t *testing.T) { // of strconv.ParseUint with these characters would have lead us to // panic. We're now only looking at ascii numbers, so test these are // compared as text. - name: "only ascii numbers", + name: "only-ascii-numbers", v1: "ÛąÛą", // 2x EXTENDED ARABIC-INDIC DIGIT ONE v2: "Û˛", // 1x EXTENDED ARABIC-INDIC DIGIT TWO want: -1, @@ -104,55 +104,55 @@ func TestCompare(t *testing.T) { // A few specific OS version tests below. { - name: "windows version", + name: "windows-version", v1: "10.0.19045.3324", v2: "10.0.18362", want: 1, }, { - name: "windows 11 is everything above 10.0.22000", + name: "windows-11-above-10_0_22000", v1: "10.0.22631.2262", v2: "10.0.22000", want: 1, }, { - name: "android short version", + name: "android-short-version", v1: "10", v2: "7", want: 1, }, { - name: "android longer version", + name: "android-longer-version", v1: "7.1.2", v2: "7", want: 1, }, { - name: "iOS version", + name: "iOS-version", v1: "15.6.1", v2: "15.6", want: 1, }, { - name: "Linux short kernel version", + name: "linux-short-kernel-version", v1: "4.4.302+", v2: "4.0", want: 1, }, { - name: "Linux long kernel version", + name: "linux-long-kernel-version", v1: "4.14.255-311-248.529.amzn2.x86_64", v2: "4.0", want: 1, }, { - name: "FreeBSD version", + name: "freebsd-version", v1: "14.0-CURRENT", v2: "14", want: 1, }, { - name: "Synology version", + name: "synology-version", v1: "Synology 6.2.4; kernel=3.10.105", v2: "Synology 6", want: 1, diff --git a/util/cstruct/cstruct.go b/util/cstruct/cstruct.go deleted file mode 100644 index afb0150bb1e77..0000000000000 --- a/util/cstruct/cstruct.go +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package cstruct provides a helper for decoding binary data that is in the -// form of a padded C structure. -package cstruct - -import ( - "encoding/binary" - "errors" - "io" -) - -// Size of a pointer-typed value, in bits -const pointerSize = 32 << (^uintptr(0) >> 63) - -// We assume that non-64-bit platforms are 32-bit; we don't expect Go to run on -// a 16- or 8-bit architecture any time soon. -const is64Bit = pointerSize == 64 - -// Decoder reads and decodes padded fields from a slice of bytes. All fields -// are decoded with native endianness. -// -// Methods of a Decoder do not return errors, but rather store any error within -// the Decoder. The first error can be obtained via the Err method; after the -// first error, methods will return the zero value for their type. -type Decoder struct { - b []byte - off int - err error - dbuf [8]byte // for decoding -} - -// NewDecoder creates a Decoder from a byte slice. -func NewDecoder(b []byte) *Decoder { - return &Decoder{b: b} -} - -var errUnsupportedSize = errors.New("unsupported size") - -func padBytes(offset, size int) int { - if offset == 0 || size == 1 { - return 0 - } - remainder := offset % size - return size - remainder -} - -func (d *Decoder) getField(b []byte) error { - size := len(b) - - // We only support fields that are multiples of 2 (or 1-sized) - if size != 1 && size&1 == 1 { - return errUnsupportedSize - } - - // Fields are aligned to their size - padBytes := padBytes(d.off, size) - if d.off+size+padBytes > len(d.b) { - return io.EOF - } - d.off += padBytes - - copy(b, d.b[d.off:d.off+size]) - d.off += size - return nil -} - -// Err returns the first error that was encountered by this Decoder. -func (d *Decoder) Err() error { - return d.err -} - -// Offset returns the current read offset for data in the buffer. -func (d *Decoder) Offset() int { - return d.off -} - -// Byte returns a single byte from the buffer. -func (d *Decoder) Byte() byte { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:1]); err != nil { - d.err = err - return 0 - } - return d.dbuf[0] -} - -// Byte returns a number of bytes from the buffer based on the size of the -// input slice. No padding is applied. -// -// If an error is encountered or this Decoder has previously encountered an -// error, no changes are made to the provided buffer. -func (d *Decoder) Bytes(b []byte) { - if d.err != nil { - return - } - - // No padding for byte slices - size := len(b) - if d.off+size >= len(d.b) { - d.err = io.EOF - return - } - copy(b, d.b[d.off:d.off+size]) - d.off += size -} - -// Uint16 returns a uint16 decoded from the buffer. -func (d *Decoder) Uint16() uint16 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:2]); err != nil { - d.err = err - return 0 - } - return binary.NativeEndian.Uint16(d.dbuf[0:2]) -} - -// Uint32 returns a uint32 decoded from the buffer. -func (d *Decoder) Uint32() uint32 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:4]); err != nil { - d.err = err - return 0 - } - return binary.NativeEndian.Uint32(d.dbuf[0:4]) -} - -// Uint64 returns a uint64 decoded from the buffer. -func (d *Decoder) Uint64() uint64 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:8]); err != nil { - d.err = err - return 0 - } - return binary.NativeEndian.Uint64(d.dbuf[0:8]) -} - -// Uintptr returns a uintptr decoded from the buffer. -func (d *Decoder) Uintptr() uintptr { - if d.err != nil { - return 0 - } - - if is64Bit { - return uintptr(d.Uint64()) - } else { - return uintptr(d.Uint32()) - } -} - -// Int16 returns a int16 decoded from the buffer. -func (d *Decoder) Int16() int16 { - return int16(d.Uint16()) -} - -// Int32 returns a int32 decoded from the buffer. -func (d *Decoder) Int32() int32 { - return int32(d.Uint32()) -} - -// Int64 returns a int64 decoded from the buffer. -func (d *Decoder) Int64() int64 { - return int64(d.Uint64()) -} diff --git a/util/cstruct/cstruct_example_test.go b/util/cstruct/cstruct_example_test.go deleted file mode 100644 index a665abe355f6a..0000000000000 --- a/util/cstruct/cstruct_example_test.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Only built on 64-bit platforms to avoid complexity - -//go:build amd64 || arm64 || mips64le || ppc64le || riscv64 - -package cstruct - -import "fmt" - -// This test provides a semi-realistic example of how you can -// use this package to decode a C structure. -func ExampleDecoder() { - // Our example C structure: - // struct mystruct { - // char *p; - // char c; - // /* implicit: char _pad[3]; */ - // int x; - // }; - // - // The Go structure definition: - type myStruct struct { - Ptr uintptr - Ch byte - Intval uint32 - } - - // Our "in-memory" version of the above structure - buf := []byte{ - 1, 2, 3, 4, 0, 0, 0, 0, // ptr - 5, // ch - 99, 99, 99, // padding - 78, 6, 0, 0, // x - } - d := NewDecoder(buf) - - // Decode the structure; if one of these function returns an error, - // then subsequent decoder functions will return the zero value. - var x myStruct - x.Ptr = d.Uintptr() - x.Ch = d.Byte() - x.Intval = d.Uint32() - - // Note that per the Go language spec: - // [...] when evaluating the operands of an expression, assignment, - // or return statement, all function calls, method calls, and - // (channel) communication operations are evaluated in lexical - // left-to-right order - // - // Since each field is assigned via a function call, one could use the - // following snippet to decode the struct. - // x := myStruct{ - // Ptr: d.Uintptr(), - // Ch: d.Byte(), - // Intval: d.Uint32(), - // } - // - // However, this means that reordering the fields in the initialization - // statement–normally a semantically identical operation–would change - // the way the structure is parsed. Thus we do it as above with - // explicit ordering. - - // After finishing with the decoder, check errors - if err := d.Err(); err != nil { - panic(err) - } - - // Print the decoder offset and structure - fmt.Printf("off=%d struct=%#v\n", d.Offset(), x) - // Output: off=16 struct=cstruct.myStruct{Ptr:0x4030201, Ch:0x5, Intval:0x64e} -} diff --git a/util/cstruct/cstruct_test.go b/util/cstruct/cstruct_test.go deleted file mode 100644 index 95d4876ca9256..0000000000000 --- a/util/cstruct/cstruct_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package cstruct - -import ( - "errors" - "fmt" - "io" - "testing" -) - -func TestPadBytes(t *testing.T) { - testCases := []struct { - offset int - size int - want int - }{ - // No padding at beginning of structure - {0, 1, 0}, - {0, 2, 0}, - {0, 4, 0}, - {0, 8, 0}, - - // No padding for single bytes - {1, 1, 0}, - - // Single byte padding - {1, 2, 1}, - {3, 4, 1}, - - // Multi-byte padding - {1, 4, 3}, - {2, 8, 6}, - } - for _, tc := range testCases { - t.Run(fmt.Sprintf("%d_%d_%d", tc.offset, tc.size, tc.want), func(t *testing.T) { - got := padBytes(tc.offset, tc.size) - if got != tc.want { - t.Errorf("got=%d; want=%d", got, tc.want) - } - }) - } -} - -func TestDecoder(t *testing.T) { - t.Run("UnsignedTypes", func(t *testing.T) { - dec := func(n int) *Decoder { - buf := make([]byte, n) - buf[0] = 1 - - d := NewDecoder(buf) - - // Use t.Cleanup to perform an assertion on this - // decoder after the test code is finished with it. - t.Cleanup(func() { - if err := d.Err(); err != nil { - t.Fatal(err) - } - }) - return d - } - if got := dec(2).Uint16(); got != 1 { - t.Errorf("uint16: got=%d; want=1", got) - } - if got := dec(4).Uint32(); got != 1 { - t.Errorf("uint32: got=%d; want=1", got) - } - if got := dec(8).Uint64(); got != 1 { - t.Errorf("uint64: got=%d; want=1", got) - } - if got := dec(pointerSize / 8).Uintptr(); got != 1 { - t.Errorf("uintptr: got=%d; want=1", got) - } - }) - - t.Run("SignedTypes", func(t *testing.T) { - dec := func(n int) *Decoder { - // Make a buffer of the exact size that consists of 0xff bytes - buf := make([]byte, n) - for i := range n { - buf[i] = 0xff - } - - d := NewDecoder(buf) - - // Use t.Cleanup to perform an assertion on this - // decoder after the test code is finished with it. - t.Cleanup(func() { - if err := d.Err(); err != nil { - t.Fatal(err) - } - }) - return d - } - if got := dec(2).Int16(); got != -1 { - t.Errorf("int16: got=%d; want=-1", got) - } - if got := dec(4).Int32(); got != -1 { - t.Errorf("int32: got=%d; want=-1", got) - } - if got := dec(8).Int64(); got != -1 { - t.Errorf("int64: got=%d; want=-1", got) - } - }) - - t.Run("InsufficientData", func(t *testing.T) { - dec := func(n int) *Decoder { - // Make a buffer that's too small and contains arbitrary bytes - buf := make([]byte, n-1) - for i := range n - 1 { - buf[i] = 0xAD - } - - // Use t.Cleanup to perform an assertion on this - // decoder after the test code is finished with it. - d := NewDecoder(buf) - t.Cleanup(func() { - if err := d.Err(); err == nil || !errors.Is(err, io.EOF) { - t.Errorf("(n=%d) expected io.EOF; got=%v", n, err) - } - }) - return d - } - - dec(2).Uint16() - dec(4).Uint32() - dec(8).Uint64() - dec(pointerSize / 8).Uintptr() - - dec(2).Int16() - dec(4).Int32() - dec(8).Int64() - }) - - t.Run("Bytes", func(t *testing.T) { - d := NewDecoder([]byte("hello worldasdf")) - t.Cleanup(func() { - if err := d.Err(); err != nil { - t.Fatal(err) - } - }) - - buf := make([]byte, 11) - d.Bytes(buf) - if got := string(buf); got != "hello world" { - t.Errorf("bytes: got=%q; want=%q", got, "hello world") - } - }) -} diff --git a/util/deephash/deephash_test.go b/util/deephash/deephash_test.go index c50d70bc6ed7f..a82203d503591 100644 --- a/util/deephash/deephash_test.go +++ b/util/deephash/deephash_test.go @@ -24,7 +24,6 @@ import ( "go4.org/netipx" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" "tailscale.com/util/deephash/testtype" "tailscale.com/util/hashx" "tailscale.com/version" @@ -60,7 +59,7 @@ func TestHash(t *testing.T) { I16 int16 I32 int32 I64 int64 - I int + Int int U8 uint8 U16 uint16 U32 uint32 @@ -93,7 +92,7 @@ func TestHash(t *testing.T) { {in: tuple{scalars{I16: math.MinInt16}, scalars{I16: math.MinInt16 / 2}}, wantEq: false}, {in: tuple{scalars{I32: math.MinInt32}, scalars{I32: math.MinInt32 / 2}}, wantEq: false}, {in: tuple{scalars{I64: math.MinInt64}, scalars{I64: math.MinInt64 / 2}}, wantEq: false}, - {in: tuple{scalars{I: -1234}, scalars{I: -1234 / 2}}, wantEq: false}, + {in: tuple{scalars{Int: -1234}, scalars{Int: -1234 / 2}}, wantEq: false}, {in: tuple{scalars{U8: math.MaxUint8}, scalars{U8: math.MaxUint8 / 2}}, wantEq: false}, {in: tuple{scalars{U16: math.MaxUint16}, scalars{U16: math.MaxUint16 / 2}}, wantEq: false}, {in: tuple{scalars{U32: math.MaxUint32}, scalars{U32: math.MaxUint32 / 2}}, wantEq: false}, @@ -361,17 +360,17 @@ func TestGetTypeHasher(t *testing.T) { out32: "\x01\x04\x00\x00\x00\x00\x00\x00\x00\x01\x03\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00*\v\x00\x00\x00\x00\x00\x00\x0010.1.3.4/32\v\x00\x00\x00\x00\x00\x00\x0010.0.0.0/24\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\n\x00\x00\x00\x00\x00\x00\x001.2.3.4/32\x01 \x00\x00\x00\x01\x00\x02\x00\x01\x04\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04!\x01\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00foo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\v\x00\x00\x00\x00\x00\x00\x00foooooooooo\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\f\x00\x00\x00\x00\x00\x00\x00baaaaaarrrrr\x00\x01\x00\x02\x00\x00\x00", }, { - name: "netip.Addr", + name: "netip-Addr", val: netip.MustParseAddr("fe80::123%foo"), out: u64(16+3) + u64(0x80fe) + u64(0x2301<<48) + "foo", }, { - name: "ptr-netip.Addr", + name: "ptr-netip-Addr", val: &someIP, out: u8(1) + u64(4) + u32(0x04030201), }, { - name: "ptr-nil-netip.Addr", + name: "ptr-nil-netip-Addr", val: (*netip.Addr)(nil), out: "\x00", }, @@ -382,7 +381,7 @@ func TestGetTypeHasher(t *testing.T) { }, { name: "time_ptr", // addressable, as opposed to "time" test above - val: ptr.To(time.Unix(1234, 5678).In(time.UTC)), + val: new(time.Unix(1234, 5678).In(time.UTC)), out: u8(1) + u64(1234) + u32(5678) + u32(0), }, { @@ -412,7 +411,7 @@ func TestGetTypeHasher(t *testing.T) { }, { name: "array_ptr_memhash", - val: ptr.To([4]byte{1, 2, 3, 4}), + val: new([4]byte{1, 2, 3, 4}), out: "\x01\x01\x02\x03\x04", }, { @@ -470,7 +469,7 @@ func TestGetTypeHasher(t *testing.T) { out: "\x01\x01\x00\x00\x00\x02\x00\x00\x00\x03\x04\x00\x00\x00\x05\x00\x00\x00\x06\x00\x00\x00\a\b\x00\x00\x00", }, { - name: "tailcfg.Node", + name: "tailcfg-Node", val: &tailcfg.Node{}, out: "ANY", // magic value; just check it doesn't fail to hash out32: "ANY", @@ -640,7 +639,7 @@ var filterRules = []tailcfg.FilterRule{ SrcIPs: []string{"*", "10.1.3.4/32", "10.0.0.0/24"}, DstPorts: []tailcfg.NetPortRange{{ IP: "1.2.3.4/32", - Bits: ptr.To(32), + Bits: new(32), Ports: tailcfg.PortRange{First: 1, Last: 2}, }}, IPProto: []int{1, 2, 3, 4}, @@ -823,7 +822,7 @@ func TestHashThroughView(t *testing.T) { SSHPolicy: &sshPolicyOut{ Rules: []tailcfg.SSHRuleView{ (&tailcfg.SSHRule{ - RuleExpires: ptr.To(time.Unix(123, 0)), + RuleExpires: new(time.Unix(123, 0)), }).View(), }, }, diff --git a/util/dnsname/dnsname.go b/util/dnsname/dnsname.go index 263c376aac674..cf1ae62000956 100644 --- a/util/dnsname/dnsname.go +++ b/util/dnsname/dnsname.go @@ -234,7 +234,7 @@ func ValidHostname(hostname string) error { return err } - for _, label := range strings.Split(fqdn.WithoutTrailingDot(), ".") { + for label := range strings.SplitSeq(fqdn.WithoutTrailingDot(), ".") { if err := ValidLabel(label); err != nil { return err } diff --git a/util/eventbus/eventbustest/eventbustest_test.go b/util/eventbus/eventbustest/eventbustest_test.go index 810312fcb411a..3c8b5aee86d8f 100644 --- a/util/eventbus/eventbustest/eventbustest_test.go +++ b/util/eventbus/eventbustest/eventbustest_test.go @@ -36,17 +36,17 @@ func TestExpectFilter(t *testing.T) { wantErr string // if non-empty, an error is expected containing this text }{ { - name: "single event", + name: "single-event", events: []int{42}, expectFunc: eventbustest.Type[EventFoo](), }, { - name: "multiple events, single expectation", + name: "multiple-events-single-expectation", events: []int{42, 1, 2, 3, 4, 5}, expectFunc: eventbustest.Type[EventFoo](), }, { - name: "filter on event with function", + name: "filter-on-event-with-function", events: []int{24, 42}, expectFunc: func(event EventFoo) (bool, error) { if event.Value == 42 { @@ -77,7 +77,7 @@ func TestExpectFilter(t *testing.T) { wantErr: "value > 10", }, { - name: "first event has to be func", + name: "first-event-has-to-be-func", events: []int{24, 42}, expectFunc: func(event EventFoo) (bool, error) { if event.Value != 42 { @@ -99,7 +99,7 @@ func TestExpectFilter(t *testing.T) { wantErr: "wrong result (-got, +want)", }, { - name: "no events", + name: "no-events", events: []int{}, expectFunc: func(event EventFoo) (bool, error) { return true, nil @@ -151,37 +151,37 @@ func TestExpectEvents(t *testing.T) { wantErr bool }{ { - name: "No expectations", + name: "no-expectations", events: []any{EventFoo{}}, expectEvents: []any{}, wantErr: true, }, { - name: "One event", + name: "one-event", events: []any{EventFoo{}}, expectEvents: []any{eventbustest.Type[EventFoo]()}, wantErr: false, }, { - name: "Two events", + name: "two-events", events: []any{EventFoo{}, EventBar{}}, expectEvents: []any{eventbustest.Type[EventFoo](), eventbustest.Type[EventBar]()}, wantErr: false, }, { - name: "Two expected events with another in the middle", + name: "two-expected-events-with-another-in-middle", events: []any{EventFoo{}, EventBaz{}, EventBar{}}, expectEvents: []any{eventbustest.Type[EventFoo](), eventbustest.Type[EventBar]()}, wantErr: false, }, { - name: "Missing event", + name: "missing-event", events: []any{EventFoo{}, EventBaz{}}, expectEvents: []any{eventbustest.Type[EventFoo](), eventbustest.Type[EventBar]()}, wantErr: true, }, { - name: "One event with specific value", + name: "one-event-with-specific-value", events: []any{EventFoo{42}}, expectEvents: []any{ func(ev EventFoo) (bool, error) { @@ -194,7 +194,7 @@ func TestExpectEvents(t *testing.T) { wantErr: false, }, { - name: "Two event with one specific value", + name: "two-events-with-one-specific-value", events: []any{EventFoo{43}, EventFoo{42}}, expectEvents: []any{ func(ev EventFoo) (bool, error) { @@ -207,7 +207,7 @@ func TestExpectEvents(t *testing.T) { wantErr: false, }, { - name: "One event with wrong value", + name: "one-event-with-wrong-value", events: []any{EventFoo{43}}, expectEvents: []any{ func(ev EventFoo) (bool, error) { @@ -220,7 +220,7 @@ func TestExpectEvents(t *testing.T) { wantErr: true, }, { - name: "Two events with specific values", + name: "two-events-with-specific-values", events: []any{EventFoo{42}, EventFoo{42}, EventBar{"42"}}, expectEvents: []any{ func(ev EventFoo) (bool, error) { @@ -283,37 +283,37 @@ func TestExpectExactlyEventsFilter(t *testing.T) { wantErr bool }{ { - name: "No expectations", + name: "no-expectations", events: []any{EventFoo{}}, expectEvents: []any{}, wantErr: true, }, { - name: "One event", + name: "one-event", events: []any{EventFoo{}}, expectEvents: []any{eventbustest.Type[EventFoo]()}, wantErr: false, }, { - name: "Two events", + name: "two-events", events: []any{EventFoo{}, EventBar{}}, expectEvents: []any{eventbustest.Type[EventFoo](), eventbustest.Type[EventBar]()}, wantErr: false, }, { - name: "Two expected events with another in the middle", + name: "two-expected-events-with-another-in-middle", events: []any{EventFoo{}, EventBaz{}, EventBar{}}, expectEvents: []any{eventbustest.Type[EventFoo](), eventbustest.Type[EventBar]()}, wantErr: true, }, { - name: "Missing event", + name: "missing-event", events: []any{EventFoo{}, EventBaz{}}, expectEvents: []any{eventbustest.Type[EventFoo](), eventbustest.Type[EventBar]()}, wantErr: true, }, { - name: "One event with value", + name: "one-event-with-value", events: []any{EventFoo{42}}, expectEvents: []any{ func(ev EventFoo) (bool, error) { @@ -326,7 +326,7 @@ func TestExpectExactlyEventsFilter(t *testing.T) { wantErr: false, }, { - name: "Two event with one specific value", + name: "two-events-with-one-specific-value", events: []any{EventFoo{43}, EventFoo{42}}, expectEvents: []any{ func(ev EventFoo) (bool, error) { @@ -339,7 +339,7 @@ func TestExpectExactlyEventsFilter(t *testing.T) { wantErr: true, }, { - name: "One event with wrong value", + name: "one-event-with-wrong-value", events: []any{EventFoo{43}}, expectEvents: []any{ func(ev EventFoo) (bool, error) { @@ -352,7 +352,7 @@ func TestExpectExactlyEventsFilter(t *testing.T) { wantErr: true, }, { - name: "Two events with specific values", + name: "two-events-with-specific-values", events: []any{EventFoo{42}, EventFoo{42}, EventBar{"42"}}, expectEvents: []any{ func(ev EventFoo) (bool, error) { diff --git a/util/expvarx/expvarx.go b/util/expvarx/expvarx.go deleted file mode 100644 index 6dc2379b961a5..0000000000000 --- a/util/expvarx/expvarx.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package expvarx provides some extensions to the [expvar] package. -package expvarx - -import ( - "encoding/json" - "expvar" - "time" - - "tailscale.com/syncs" - "tailscale.com/types/lazy" -) - -// SafeFunc is a wrapper around [expvar.Func] that guards against unbounded call -// time and ensures that only a single call is in progress at any given time. -type SafeFunc struct { - f expvar.Func - limit time.Duration - onSlow func(time.Duration, any) - - mu syncs.Mutex - inflight *lazy.SyncValue[any] -} - -// NewSafeFunc returns a new SafeFunc that wraps f. -// If f takes longer than limit to execute then Value calls return nil. -// If onSlow is non-nil, it is called when f takes longer than limit to execute. -// onSlow is called with the duration of the slow call and the final computed -// value. -func NewSafeFunc(f expvar.Func, limit time.Duration, onSlow func(time.Duration, any)) *SafeFunc { - return &SafeFunc{f: f, limit: limit, onSlow: onSlow} -} - -// Value acts similarly to [expvar.Func.Value], but if the underlying function -// takes longer than the configured limit, all callers will receive nil until -// the underlying operation completes. On completion of the underlying -// operation, the onSlow callback is called if set. -func (s *SafeFunc) Value() any { - s.mu.Lock() - - if s.inflight == nil { - s.inflight = new(lazy.SyncValue[any]) - } - var inflight = s.inflight - s.mu.Unlock() - - // inflight ensures that only a single work routine is spawned at any given - // time, but if the routine takes too long inflight is populated with a nil - // result. The long running computed value is lost forever. - return inflight.Get(func() any { - start := time.Now() - result := make(chan any, 1) - - // work is spawned in routine so that the caller can timeout. - go func() { - // Allow new work to be started after this work completes - defer func() { - s.mu.Lock() - s.inflight = nil - s.mu.Unlock() - - }() - - v := s.f.Value() - result <- v - }() - - select { - case v := <-result: - return v - case <-time.After(s.limit): - if s.onSlow != nil { - go func() { - s.onSlow(time.Since(start), <-result) - }() - } - return nil - } - }) -} - -// String implements stringer in the same pattern as [expvar.Func], calling -// Value and serializing the result as JSON, ignoring errors. -func (s *SafeFunc) String() string { - v, _ := json.Marshal(s.Value()) - return string(v) -} diff --git a/util/expvarx/expvarx_test.go b/util/expvarx/expvarx_test.go deleted file mode 100644 index f8d2139d3ecb1..0000000000000 --- a/util/expvarx/expvarx_test.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package expvarx - -import ( - "expvar" - "fmt" - "sync" - "sync/atomic" - "testing" - "testing/synctest" - "time" -) - -func ExampleNewSafeFunc() { - // An artificial blocker to emulate a slow operation. - blocker := make(chan struct{}) - - // limit is the amount of time a call can take before Value returns nil. No - // new calls to the unsafe func will be started until the slow call - // completes, at which point onSlow will be called. - limit := time.Millisecond - - // onSlow is called with the final call duration and the final value in the - // event a slow call. - onSlow := func(d time.Duration, v any) { - _ = d // d contains the time the call took - _ = v // v contains the final value computed by the slow call - fmt.Println("slow call!") - } - - // An unsafe expvar.Func that blocks on the blocker channel. - unsafeFunc := expvar.Func(func() any { - for range blocker { - } - return "hello world" - }) - - // f implements the same interface as expvar.Func, but returns nil values - // when the unsafe func is too slow. - f := NewSafeFunc(unsafeFunc, limit, onSlow) - - fmt.Println(f.Value()) - fmt.Println(f.Value()) - close(blocker) - time.Sleep(time.Millisecond) - fmt.Println(f.Value()) - // Output: - // - // slow call! - // hello world -} - -func TestSafeFuncHappyPath(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - var count int - f := NewSafeFunc(expvar.Func(func() any { - count++ - return count - }), time.Second, nil) - - if got, want := f.Value(), 1; got != want { - t.Errorf("got %v, want %v", got, want) - } - time.Sleep(5 * time.Second) // (fake time in synctest) - if got, want := f.Value(), 2; got != want { - t.Errorf("got %v, want %v", got, want) - } - }) -} - -func TestSafeFuncSlow(t *testing.T) { - var count int - blocker := make(chan struct{}) - var wg sync.WaitGroup - wg.Add(1) - f := NewSafeFunc(expvar.Func(func() any { - defer wg.Done() - count++ - <-blocker - return count - }), time.Millisecond, nil) - - if got := f.Value(); got != nil { - t.Errorf("got %v; want nil", got) - } - if got := f.Value(); got != nil { - t.Errorf("got %v; want nil", got) - } - - close(blocker) - wg.Wait() - - if count != 1 { - t.Errorf("got count=%d; want 1", count) - } -} - -func TestSafeFuncSlowOnSlow(t *testing.T) { - var count int - blocker := make(chan struct{}) - var wg sync.WaitGroup - wg.Add(2) - var slowDuration atomic.Pointer[time.Duration] - var slowCallCount atomic.Int32 - var slowValue atomic.Value - f := NewSafeFunc(expvar.Func(func() any { - defer wg.Done() - count++ - <-blocker - return count - }), time.Millisecond, func(d time.Duration, v any) { - defer wg.Done() - slowDuration.Store(&d) - slowCallCount.Add(1) - slowValue.Store(v) - }) - - for range 10 { - if got := f.Value(); got != nil { - t.Fatalf("got value=%v; want nil", got) - } - } - - close(blocker) - wg.Wait() - - if count != 1 { - t.Errorf("got count=%d; want 1", count) - } - if got, want := *slowDuration.Load(), 1*time.Millisecond; got < want { - t.Errorf("got slowDuration=%v; want at least %d", got, want) - } - if got, want := slowCallCount.Load(), int32(1); got != want { - t.Errorf("got slowCallCount=%d; want %d", got, want) - } - if got, want := slowValue.Load().(int), 1; got != want { - t.Errorf("got slowValue=%d, want %d", got, want) - } -} diff --git a/util/goroutines/goroutines.go b/util/goroutines/goroutines.go index fd0a4dd7eb321..f184fcd6c9e73 100644 --- a/util/goroutines/goroutines.go +++ b/util/goroutines/goroutines.go @@ -52,7 +52,7 @@ func scrubHex(buf []byte) []byte { in[0] = '?' return } - v := []byte(fmt.Sprintf("v%d%%%d", len(saw)+1, u64%8)) + v := fmt.Appendf(nil, "v%d%%%d", len(saw)+1, u64%8) saw[inStr] = v copy(in, v) }) diff --git a/util/hashx/block512_test.go b/util/hashx/block512_test.go index 91d5d9ee67749..03c77eabbecc3 100644 --- a/util/hashx/block512_test.go +++ b/util/hashx/block512_test.go @@ -47,7 +47,7 @@ type hasher interface { func hashSuite(h hasher) { for i := range 10 { - for j := 0; j < 10; j++ { + for range 10 { h.HashUint8(0x01) h.HashUint8(0x23) h.HashUint32(0x456789ab) diff --git a/util/httphdr/httphdr.go b/util/httphdr/httphdr.go index 01e8eddc67ac1..852b3f5c74138 100644 --- a/util/httphdr/httphdr.go +++ b/util/httphdr/httphdr.go @@ -44,7 +44,7 @@ func ParseRange(hdr string) (ranges []Range, ok bool) { hdr = strings.Trim(hdr, ows) // per RFC 7230, section 3.2 units, elems, hasUnits := strings.Cut(hdr, "=") elems = strings.TrimLeft(elems, ","+ows) - for _, elem := range strings.Split(elems, ",") { + for elem := range strings.SplitSeq(elems, ",") { elem = strings.Trim(elem, ows) // per RFC 7230, section 7 switch { case strings.HasPrefix(elem, "-"): // i.e., "-" suffix-length diff --git a/util/httpm/httpm_test.go b/util/httpm/httpm_test.go index 4e7f7b5ab277c..e8342a74f26b6 100644 --- a/util/httpm/httpm_test.go +++ b/util/httpm/httpm_test.go @@ -24,10 +24,17 @@ func TestUsedConsistently(t *testing.T) { t.Skipf("skipping test since .git doesn't exist: %v", err) } + // Open .git/index so Go's test cache tracks it as an input. + // The index file changes on git reset, checkout, pull, etc., + // so the cache is properly invalidated when moving between commits. + if f, err := os.Open(filepath.Join(rootDir, ".git", "index")); err == nil { + f.Close() + } + cmd := exec.Command("git", "grep", "-l", "-F", "http.Method") cmd.Dir = rootDir matches, _ := cmd.Output() - for _, fn := range strings.Split(strings.TrimSpace(string(matches)), "\n") { + for fn := range strings.SplitSeq(strings.TrimSpace(string(matches)), "\n") { switch fn { case "util/httpm/httpm.go", "util/httpm/httpm_test.go": continue diff --git a/util/linuxfw/fake.go b/util/linuxfw/fake.go index deeae87603f8a..b902b93c1a66b 100644 --- a/util/linuxfw/fake.go +++ b/util/linuxfw/fake.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "os" + "slices" "strconv" "strings" ) @@ -62,10 +63,8 @@ func (n *fakeIPTables) Append(table, chain string, args ...string) error { func (n *fakeIPTables) Exists(table, chain string, args ...string) (bool, error) { k := table + "/" + chain if rules, ok := n.n[k]; ok { - for _, rule := range rules { - if rule == strings.Join(args, " ") { - return true, nil - } + if slices.Contains(rules, strings.Join(args, " ")) { + return true, nil } return false, nil } else { diff --git a/util/linuxfw/fake_netfilter.go b/util/linuxfw/fake_netfilter.go index eac5d904cff3a..1ecfc1c39993e 100644 --- a/util/linuxfw/fake_netfilter.go +++ b/util/linuxfw/fake_netfilter.go @@ -19,6 +19,8 @@ type FakeNetfilterRunner struct { TailscaleServiceIP netip.Addr ClusterIP netip.Addr } + // clampedAddrs tracks addresses passed to ClampMSSToPMTU. + clampedAddrs []netip.Addr } // NewFakeNetfilterRunner creates a new FakeNetfilterRunner. @@ -83,7 +85,15 @@ func (f *FakeNetfilterRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []ne } func (f *FakeNetfilterRunner) EnsureSNATForDst(src, dst netip.Addr) error { return nil } func (f *FakeNetfilterRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error { return nil } -func (f *FakeNetfilterRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { + f.clampedAddrs = append(f.clampedAddrs, addr) + return nil +} + +// GetClampedAddrs returns the addresses passed to ClampMSSToPMTU. +func (f *FakeNetfilterRunner) GetClampedAddrs() []netip.Addr { + return f.clampedAddrs +} func (f *FakeNetfilterRunner) AddMagicsockPortRule(port uint16, network string) error { return nil } func (f *FakeNetfilterRunner) DelMagicsockPortRule(port uint16, network string) error { return nil } func (f *FakeNetfilterRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error { @@ -95,3 +105,5 @@ func (f *FakeNetfilterRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, func (f *FakeNetfilterRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error { return nil } +func (f *FakeNetfilterRunner) AddExternalCGNATRules(mode CGNATMode, tunname string) error { return nil } +func (f *FakeNetfilterRunner) DelExternalCGNATRules(mode CGNATMode, tunname string) error { return nil } diff --git a/util/linuxfw/iptables.go b/util/linuxfw/iptables.go index f054e7abe1718..3bd2c288699e4 100644 --- a/util/linuxfw/iptables.go +++ b/util/linuxfw/iptables.go @@ -21,8 +21,8 @@ import ( func init() { isNotExistError = func(err error) bool { - var e *iptables.Error - return errors.As(err, &e) && e.IsNotExist() + e, ok := errors.AsType[*iptables.Error](err) + return ok && e.IsNotExist() } } diff --git a/util/linuxfw/linuxfw.go b/util/linuxfw/linuxfw.go index 325a5809f8586..ed130a2b1416b 100644 --- a/util/linuxfw/linuxfw.go +++ b/util/linuxfw/linuxfw.go @@ -53,6 +53,13 @@ const ( FirewallModeNfTables FirewallMode = "nftables" ) +type CGNATMode string + +const ( + CGNATModeDrop CGNATMode = "DROP" + CGNATModeReturn CGNATMode = "RETURN" +) + // The following bits are added to packet marks for Tailscale use. // // We tried to pick bits sufficiently out of the way that it's diff --git a/util/linuxfw/linuxfwtest/linuxfwtest.go b/util/linuxfw/linuxfwtest/linuxfwtest.go deleted file mode 100644 index bf1477ad9b994..0000000000000 --- a/util/linuxfw/linuxfwtest/linuxfwtest.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build cgo && linux - -// Package linuxfwtest contains tests for the linuxfw package. Go does not -// support cgo in tests, and we don't want the main package to have a cgo -// dependency, so we put all the tests here and call them from the main package -// in tests intead. -package linuxfwtest - -import ( - "testing" - "unsafe" -) - -/* -#include // socket() -*/ -import "C" - -type SizeInfo struct { - SizeofSocklen uintptr -} - -func TestSizes(t *testing.T, si *SizeInfo) { - want := unsafe.Sizeof(C.socklen_t(0)) - if want != si.SizeofSocklen { - t.Errorf("sockLen has wrong size; want=%d got=%d", want, si.SizeofSocklen) - } -} diff --git a/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go b/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go deleted file mode 100644 index ec2d24d3521c9..0000000000000 --- a/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !cgo || !linux - -package linuxfwtest - -import ( - "testing" -) - -type SizeInfo struct { - SizeofSocklen uintptr -} - -func TestSizes(t *testing.T, si *SizeInfo) { - t.Skip("not supported without cgo") -} diff --git a/util/linuxfw/nftables_for_svcs.go b/util/linuxfw/nftables_for_svcs.go index c2425e2ff285b..35764a2bde5da 100644 --- a/util/linuxfw/nftables_for_svcs.go +++ b/util/linuxfw/nftables_for_svcs.go @@ -236,7 +236,7 @@ func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP net // This metadata can then be used to find the rule. // https://github.com/google/nftables/issues/48 func svcPortMapRuleMeta(svcName string, targetIP netip.Addr, pm PortMap) []byte { - return []byte(fmt.Sprintf("svc:%s,targetIP:%s:matchPort:%v,targetPort:%v,proto:%v", svcName, targetIP.String(), pm.MatchPort, pm.TargetPort, pm.Protocol)) + return fmt.Appendf(nil, "svc:%s,targetIP:%s:matchPort:%v,targetPort:%v,proto:%v", svcName, targetIP.String(), pm.MatchPort, pm.TargetPort, pm.Protocol) } func (n *nftablesRunner) findRuleByMetadata(t *nftables.Table, ch *nftables.Chain, meta []byte) (*nftables.Rule, error) { @@ -305,5 +305,5 @@ func protoFromString(s string) (uint8, error) { // This metadata can then be used to find the rule. // https://github.com/google/nftables/issues/48 func svcRuleMeta(svcName string, origDst, dst netip.Addr) []byte { - return []byte(fmt.Sprintf("svc:%s,VIP:%s,ClusterIP:%s", svcName, origDst.String(), dst.String())) + return fmt.Appendf(nil, "svc:%s,VIP:%s,ClusterIP:%s", svcName, origDst.String(), dst.String()) } diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index 8299a9cbd72da..65df7718e10e8 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "net/netip" - "os" "runtime" "slices" "strings" @@ -522,10 +521,7 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) { func newSysConn(t *testing.T) *nftables.Conn { t.Helper() - if os.Geteuid() != 0 { - t.Skip(t.Name(), " requires privileges to create a namespace in order to run") - return nil - } + tstest.RequireRoot(t) runtime.LockOSThread() @@ -637,7 +633,7 @@ func TestAddAndDelNetfilterChains(t *testing.T) { func getTsChains( conn *nftables.Conn, proto nftables.TableFamily) (*nftables.Chain, *nftables.Chain, *nftables.Chain, error) { - chains, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + chains, err := conn.ListChainsOfTableFamily(proto) if err != nil { return nil, nil, nil, fmt.Errorf("list chains failed: %w", err) } @@ -662,17 +658,7 @@ func findV4BaseRules( forwChain *nftables.Chain, tunname string) ([]*nftables.Rule, error) { want := []*nftables.Rule{} - rule, err := createRangeRule(inpChain.Table, inpChain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn) - if err != nil { - return nil, fmt.Errorf("create rule: %w", err) - } - want = append(want, rule) - rule, err = createRangeRule(inpChain.Table, inpChain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop) - if err != nil { - return nil, fmt.Errorf("create rule: %w", err) - } - want = append(want, rule) - rule, err = createDropOutgoingPacketFromCGNATRangeRuleWithTunname(forwChain.Table, forwChain, tunname) + rule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(forwChain.Table, forwChain, tunname) if err != nil { return nil, fmt.Errorf("create rule: %w", err) } @@ -749,7 +735,7 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) { if err != nil { t.Fatalf("getTsChains() failed: %v", err) } - checkChainRules(t, conn, inputV4, 3) + checkChainRules(t, conn, inputV4, 1) checkChainRules(t, conn, forwardV4, 4) checkChainRules(t, conn, postroutingV4, 0) @@ -767,8 +753,8 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) { if err != nil { t.Fatalf("getTsChains() failed: %v", err) } - checkChainRules(t, conn, inputV6, 3) - checkChainRules(t, conn, forwardV6, 4) + checkChainRules(t, conn, inputV6, 1) + checkChainRules(t, conn, forwardV6, 3) checkChainRules(t, conn, postroutingV6, 0) _, err = findCommonBaseRules(conn, forwardV6, "testTunn") @@ -787,6 +773,92 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) { } } +func findCGNATRules( + conn *nftables.Conn, + inpChain *nftables.Chain, + mode CGNATMode, + tunname string, +) error { + want := []*nftables.Rule{} + switch mode { + case CGNATModeDrop: + rule, err := createRangeRule(inpChain.Table, inpChain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + want = append(want, rule) + rule, err = createRangeRule(inpChain.Table, inpChain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + want = append(want, rule) + case CGNATModeReturn: + rule, err := createRangeRule(inpChain.Table, inpChain, tunname, tsaddr.CGNATRange(), expr.VerdictReturn) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + want = append(want, rule) + default: + return fmt.Errorf("unknown mode %q", mode) + } + for _, rule := range want { + _, err := findRule(conn, rule) + if err != nil { + return fmt.Errorf("find rule: %w", err) + } + } + return nil +} + +func TestNFTAddAndDelCGNATRules(t *testing.T) { + modes := []CGNATMode{CGNATModeDrop, CGNATModeReturn} + for _, mode := range modes { + t.Run(string(mode), func(t *testing.T) { + conn := newSysConn(t) + + runner := newFakeNftablesRunnerWithConn(t, conn, false) + + if err := runner.AddChains(); err != nil { + t.Fatalf("AddChains() failed: %v", err) + } + defer runner.DelChains() + + inputV4, _, _, err := getTsChains(conn, nftables.TableFamilyIPv4) + if err != nil { + t.Fatalf("getTsChains() failed: %v", err) + } + + checkChainRules(t, conn, inputV4, 0) + + tunname := "tun0" + + if err := runner.AddExternalCGNATRules(mode, tunname); err != nil { + t.Fatalf("add rules: %v", err) + } + + switch mode { + case CGNATModeDrop: + checkChainRules(t, conn, inputV4, 2) + case CGNATModeReturn: + checkChainRules(t, conn, inputV4, 1) + default: + t.Fatalf("unknown mode %q", mode) + } + + if err := findCGNATRules(conn, inputV4, mode, tunname); err != nil { + t.Fatalf("find rules: %v", err) + } + + if err := runner.DelExternalCGNATRules(mode, tunname); err != nil { + t.Fatalf("delete rules: %v", err) + } + + // Verify that all the rules have been deleted (0 remaining). + checkChainRules(t, conn, inputV4, 0) + }) + } +} + func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nftables.Table, chain *nftables.Chain, addr netip.Addr) (*nftables.Rule, error) { matchingAddr := addr.AsSlice() saddrExpr, err := newLoadSaddrExpr(proto, 1) @@ -849,16 +921,16 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) { runner.AddBase("testTunn") defer runner.DelBase() - checkChainRules(t, conn, inputV4, 3) - checkChainRules(t, conn, inputV6, 3) + checkChainRules(t, conn, inputV4, 1) + checkChainRules(t, conn, inputV6, 1) addr := netip.MustParseAddr("192.168.0.2") addrV6 := netip.MustParseAddr("2001:db8::2") runner.AddLoopbackRule(addr) runner.AddLoopbackRule(addrV6) - checkChainRules(t, conn, inputV4, 4) - checkChainRules(t, conn, inputV6, 4) + checkChainRules(t, conn, inputV4, 2) + checkChainRules(t, conn, inputV6, 2) existingLoopBackRule, err := findLoopBackRule(conn, nftables.TableFamilyIPv4, runner.nft4.Filter, inputV4, addr) if err != nil { @@ -881,8 +953,8 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) { runner.DelLoopbackRule(addr) runner.DelLoopbackRule(addrV6) - checkChainRules(t, conn, inputV4, 3) - checkChainRules(t, conn, inputV6, 3) + checkChainRules(t, conn, inputV4, 1) + checkChainRules(t, conn, inputV6, 1) } func TestNFTAddAndDelHookRule(t *testing.T) { @@ -960,32 +1032,32 @@ func TestPickFirewallModeFromInstalledRules(t *testing.T) { want FirewallMode }{ { - name: "using iptables legacy", + name: "using-iptables-legacy", det: &testFWDetector{iptRuleCount: 1}, want: FirewallModeIPTables, }, { - name: "using nftables", + name: "using-nftables", det: &testFWDetector{nftRuleCount: 1}, want: FirewallModeNfTables, }, { - name: "using both iptables and nftables", + name: "using-both-iptables-and-nftables", det: &testFWDetector{iptRuleCount: 2, nftRuleCount: 2}, want: FirewallModeNfTables, }, { - name: "not using any firewall, both available", + name: "no-firewall-both-available", det: &testFWDetector{}, want: FirewallModeNfTables, }, { - name: "not using any firewall, iptables available only", + name: "no-firewall-iptables-only", det: &testFWDetector{iptRuleCount: 1, nftErr: errors.New("nft error")}, want: FirewallModeIPTables, }, { - name: "not using any firewall, nftables available only", + name: "no-firewall-nftables-only", det: &testFWDetector{iptErr: errors.New("iptables error"), nftRuleCount: 1}, want: FirewallModeNfTables, }, @@ -1066,7 +1138,7 @@ func checkSNATRule_nft(t *testing.T, runner *nftablesRunner, fam nftables.TableF if chain == nil { t.Fatal("POSTROUTING chain does not exist") } - meta := []byte(fmt.Sprintf("dst:%s,src:%s", dst.String(), src.String())) + meta := fmt.Appendf(nil, "dst:%s,src:%s", dst.String(), src.String()) wantsRule := snatRule(chain.Table, chain, src, dst, meta) checkRule(t, wantsRule, runner.conn) } @@ -1313,3 +1385,39 @@ func TestMakeConnmarkSaveExprs(t *testing.T) { t.Fatalf("Flush() failed: %v", err) } } + +// TestGetOrCreateChainNilHooknum verifies that getOrCreateChain returns a clear +// error when a ts- chain exists but has nil Hooknum/Priority, which happens when +// the kernel lacks nftables support (CONFIG_NF_TABLES). +func TestGetOrCreateChainNilHooknum(t *testing.T) { + conn := newSysConn(t) + + table := conn.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "ts-filter-test", + }) + // Add a ts- chain without hooknum/priority (regular chain), simulating + // the broken state returned by a kernel without nftables support. + conn.AddChain(&nftables.Chain{ + Name: "ts-input", + Table: table, + }) + if err := conn.Flush(); err != nil { + t.Fatalf("Flush() failed: %v", err) + } + + // Now try getOrCreateChain expecting a base chain with hooknum/priority. + _, err := getOrCreateChain(conn, chainInfo{ + table: table, + name: "ts-input", + chainType: nftables.ChainTypeFilter, + chainHook: nftables.ChainHookInput, + chainPriority: nftables.ChainPriorityFilter, + }) + if err == nil { + t.Fatal("expected error for chain with nil hooknum/priority, got nil") + } + if !strings.Contains(err.Error(), "nil hooknum") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/util/osdiag/osdiag_windows.go b/util/osdiag/osdiag_windows.go index d6ba1d30bb674..ff489989c1dbd 100644 --- a/util/osdiag/osdiag_windows.go +++ b/util/osdiag/osdiag_windows.go @@ -352,7 +352,7 @@ const ( ) // Note that wsaProtocolInfo needs to be identical to windows.WSAProtocolInfo; -// the purpose of this type is to have the ability to use it as a reciever in +// the purpose of this type is to have the ability to use it as a receiver in // the path and categoryFlags funcs defined below. type wsaProtocolInfo windows.WSAProtocolInfo diff --git a/util/osuser/group_ids.go b/util/osuser/group_ids.go index 2a1f147d87b00..34d15c926ae98 100644 --- a/util/osuser/group_ids.go +++ b/util/osuser/group_ids.go @@ -23,7 +23,7 @@ func GetGroupIds(user *user.User) ([]string, error) { return nil, nil } - if runtime.GOOS != "linux" { + if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { return user.GroupIds() } @@ -46,13 +46,24 @@ func getGroupIdsWithId(usernameOrUID string) ([]string, error) { defer cancel() cmd := exec.CommandContext(ctx, "id", "-Gz", usernameOrUID) - out, err := cmd.Output() + if runtime.GOOS == "freebsd" { + cmd = exec.CommandContext(ctx, "id", "-G", usernameOrUID) + } + + out, err := cmd.CombinedOutput() if err != nil { return nil, fmt.Errorf("running 'id' command: %w", err) } + return parseGroupIds(out), nil } func parseGroupIds(cmdOutput []byte) []string { - return strings.Split(strings.Trim(string(cmdOutput), "\n\x00"), "\x00") + s := strings.TrimSpace(string(cmdOutput)) + // Parse NUL-delimited output. + if strings.ContainsRune(s, '\x00') { + return strings.Split(strings.Trim(s, "\x00"), "\x00") + } + // Parse whitespace-delimited output. + return strings.Fields(s) } diff --git a/util/osuser/group_ids_test.go b/util/osuser/group_ids_test.go index 79e189ed8c866..fee86029bf4dc 100644 --- a/util/osuser/group_ids_test.go +++ b/util/osuser/group_ids_test.go @@ -15,7 +15,9 @@ func TestParseGroupIds(t *testing.T) { }{ {"5000\x005001\n", []string{"5000", "5001"}}, {"5000\n", []string{"5000"}}, - {"\n", []string{""}}, + {"\n", []string{}}, + {"5000 5001 5002\n", []string{"5000", "5001", "5002"}}, + {"5000\t5001\n", []string{"5000", "5001"}}, } for _, test := range tests { actual := parseGroupIds([]byte(test.in)) diff --git a/util/pidowner/pidowner.go b/util/pidowner/pidowner.go deleted file mode 100644 index cec92ba367e49..0000000000000 --- a/util/pidowner/pidowner.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package pidowner handles lookups from process ID to its owning user. -package pidowner - -import ( - "errors" - "runtime" -) - -var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) - -var ErrProcessNotFound = errors.New("process not found") - -// OwnerOfPID returns the user ID that owns the given process ID. -// -// The returned user ID is suitable to passing to os/user.LookupId. -// -// The returned error will be ErrNotImplemented for operating systems where -// this isn't supported. -func OwnerOfPID(pid int) (userID string, err error) { - return ownerOfPID(pid) -} diff --git a/util/pidowner/pidowner_linux.go b/util/pidowner/pidowner_linux.go deleted file mode 100644 index f3f5cd97ddcb2..0000000000000 --- a/util/pidowner/pidowner_linux.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package pidowner - -import ( - "fmt" - "os" - "strings" - - "tailscale.com/util/lineiter" -) - -func ownerOfPID(pid int) (userID string, err error) { - file := fmt.Sprintf("/proc/%d/status", pid) - for lr := range lineiter.File(file) { - line, err := lr.Value() - if err != nil { - if os.IsNotExist(err) { - return "", ErrProcessNotFound - } - return "", err - } - if len(line) < 4 || string(line[:4]) != "Uid:" { - continue - } - f := strings.Fields(string(line)) - if len(f) >= 2 { - userID = f[1] // real userid - } - } - if userID == "" { - return "", fmt.Errorf("missing Uid line in %s", file) - } - return userID, nil -} diff --git a/util/pidowner/pidowner_noimpl.go b/util/pidowner/pidowner_noimpl.go deleted file mode 100644 index 4bc665d61071e..0000000000000 --- a/util/pidowner/pidowner_noimpl.go +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !linux - -package pidowner - -func ownerOfPID(pid int) (userID string, err error) { return "", ErrNotImplemented } diff --git a/util/pidowner/pidowner_test.go b/util/pidowner/pidowner_test.go deleted file mode 100644 index 2774a8ab0fe36..0000000000000 --- a/util/pidowner/pidowner_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package pidowner - -import ( - "math/rand" - "os" - "os/user" - "testing" -) - -func TestOwnerOfPID(t *testing.T) { - id, err := OwnerOfPID(os.Getpid()) - if err == ErrNotImplemented { - t.Skip(err) - } - if err != nil { - t.Fatal(err) - } - t.Logf("id=%q", id) - - u, err := user.LookupId(id) - if err != nil { - t.Fatalf("LookupId: %v", err) - } - t.Logf("Got: %+v", u) -} - -// validate that OS implementation returns ErrProcessNotFound. -func TestNotFoundError(t *testing.T) { - // Try a bunch of times to stumble upon a pid that doesn't exist... - const tries = 50 - for range tries { - _, err := OwnerOfPID(rand.Intn(1e9)) - if err == ErrNotImplemented { - t.Skip(err) - } - if err == nil { - // We got unlucky and this pid existed. Try again. - continue - } - if err == ErrProcessNotFound { - // Pass. - return - } - t.Fatalf("Error is not ErrProcessNotFound: %T %v", err, err) - } - t.Errorf("after %d tries, couldn't find a process that didn't exist", tries) -} diff --git a/util/pidowner/pidowner_windows.go b/util/pidowner/pidowner_windows.go deleted file mode 100644 index 8edd7698d4207..0000000000000 --- a/util/pidowner/pidowner_windows.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package pidowner - -import ( - "fmt" - "syscall" - - "golang.org/x/sys/windows" -) - -func ownerOfPID(pid int) (userID string, err error) { - procHnd, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid)) - if err == syscall.Errno(0x57) { // invalid parameter, for PIDs that don't exist - return "", ErrProcessNotFound - } - if err != nil { - return "", fmt.Errorf("OpenProcess: %T %#v", err, err) - } - defer windows.CloseHandle(procHnd) - - var tok windows.Token - if err := windows.OpenProcessToken(procHnd, windows.TOKEN_QUERY, &tok); err != nil { - return "", fmt.Errorf("OpenProcessToken: %w", err) - } - - tokUser, err := tok.GetTokenUser() - if err != nil { - return "", fmt.Errorf("GetTokenUser: %w", err) - } - - sid := tokUser.User.Sid - return sid.String(), nil -} diff --git a/util/pool/pool.go b/util/pool/pool.go deleted file mode 100644 index 7042fb893a59e..0000000000000 --- a/util/pool/pool.go +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package pool contains a generic type for managing a pool of resources; for -// example, connections to a database, or to a remote service. -// -// Unlike sync.Pool from the Go standard library, this pool does not remove -// items from the pool when garbage collection happens, nor is it safe for -// concurrent use like sync.Pool. -package pool - -import ( - "fmt" - "math/rand/v2" - - "tailscale.com/types/ptr" -) - -// consistencyCheck enables additional runtime checks to ensure that the pool -// is well-formed; it is disabled by default, and can be enabled during tests -// to catch additional bugs. -const consistencyCheck = false - -// Pool is a pool of resources. It is not safe for concurrent use. -type Pool[V any] struct { - s []itemAndIndex[V] -} - -type itemAndIndex[V any] struct { - // item is the element in the pool - item V - - // index is the current location of this item in pool.s. It gets set to - // -1 when the item is removed from the pool. - index *int -} - -// Handle is an opaque handle to a resource in a pool. It is used to delete an -// item from the pool, without requiring the item to be comparable. -type Handle[V any] struct { - idx *int // pointer to index; -1 if not in slice -} - -// Len returns the current size of the pool. -func (p *Pool[V]) Len() int { - return len(p.s) -} - -// Clear removes all items from the pool. -func (p *Pool[V]) Clear() { - p.s = nil -} - -// AppendTakeAll removes all items from the pool, appending them to the -// provided slice (which can be nil) and returning them. The returned slice can -// be nil if the provided slice was nil and the pool was empty. -// -// This function does not free the backing storage for the pool; to do that, -// use the Clear function. -func (p *Pool[V]) AppendTakeAll(dst []V) []V { - ret := dst - for i := range p.s { - e := p.s[i] - if consistencyCheck && e.index == nil { - panic(fmt.Sprintf("pool: index is nil at %d", i)) - } - if *e.index >= 0 { - ret = append(ret, p.s[i].item) - } - } - p.s = p.s[:0] - return ret -} - -// Add adds an item to the pool and returns a handle to it. The handle can be -// used to delete the item from the pool with the Delete method. -func (p *Pool[V]) Add(item V) Handle[V] { - // Store the index in a pointer, so that we can pass it to both the - // handle and store it in the itemAndIndex. - idx := ptr.To(len(p.s)) - p.s = append(p.s, itemAndIndex[V]{ - item: item, - index: idx, - }) - return Handle[V]{idx} -} - -// Peek will return the item with the given handle without removing it from the -// pool. -// -// It will return ok=false if the item has been deleted or previously taken. -func (p *Pool[V]) Peek(h Handle[V]) (v V, ok bool) { - p.checkHandle(h) - idx := *h.idx - if idx < 0 { - var zero V - return zero, false - } - p.checkIndex(idx) - return p.s[idx].item, true -} - -// Delete removes the item from the pool. -// -// It reports whether the element was deleted; it will return false if the item -// has been taken with the TakeRandom function, or if the item was already -// deleted. -func (p *Pool[V]) Delete(h Handle[V]) bool { - p.checkHandle(h) - idx := *h.idx - if idx < 0 { - return false - } - p.deleteIndex(idx) - return true -} - -func (p *Pool[V]) deleteIndex(idx int) { - // Mark the item as deleted. - p.checkIndex(idx) - *(p.s[idx].index) = -1 - - // If this isn't the last element in the slice, overwrite the element - // at this item's index with the last element. - lastIdx := len(p.s) - 1 - - if idx < lastIdx { - last := p.s[lastIdx] - p.checkElem(lastIdx, last) - *last.index = idx - p.s[idx] = last - } - - // Zero out last element (for GC) and truncate slice. - p.s[lastIdx] = itemAndIndex[V]{} - p.s = p.s[:lastIdx] -} - -// Take will remove the item with the given handle from the pool and return it. -// -// It will return ok=false and the zero value if the item has been deleted or -// previously taken. -func (p *Pool[V]) Take(h Handle[V]) (v V, ok bool) { - p.checkHandle(h) - idx := *h.idx - if idx < 0 { - var zero V - return zero, false - } - - e := p.s[idx] - p.deleteIndex(idx) - return e.item, true -} - -// TakeRandom returns and removes a random element from p -// and reports whether there was one to take. -// -// It will return ok=false and the zero value if the pool is empty. -func (p *Pool[V]) TakeRandom() (v V, ok bool) { - if len(p.s) == 0 { - var zero V - return zero, false - } - pick := rand.IntN(len(p.s)) - e := p.s[pick] - p.checkElem(pick, e) - p.deleteIndex(pick) - return e.item, true -} - -// checkIndex verifies that the provided index is within the bounds of the -// pool's slice, and that the corresponding element has a non-nil index -// pointer, and panics if not. -func (p *Pool[V]) checkIndex(idx int) { - if !consistencyCheck { - return - } - - if idx >= len(p.s) { - panic(fmt.Sprintf("pool: index %d out of range (len %d)", idx, len(p.s))) - } - if p.s[idx].index == nil { - panic(fmt.Sprintf("pool: index is nil at %d", idx)) - } -} - -// checkHandle verifies that the provided handle is not nil, and panics if it -// is. -func (p *Pool[V]) checkHandle(h Handle[V]) { - if !consistencyCheck { - return - } - - if h.idx == nil { - panic("pool: nil handle") - } -} - -// checkElem verifies that the provided itemAndIndex has a non-nil index, and -// that the stored index matches the expected position within the slice. -func (p *Pool[V]) checkElem(idx int, e itemAndIndex[V]) { - if !consistencyCheck { - return - } - - if e.index == nil { - panic("pool: index is nil") - } - if got := *e.index; got != idx { - panic(fmt.Sprintf("pool: index is incorrect: want %d, got %d", idx, got)) - } -} diff --git a/util/pool/pool_test.go b/util/pool/pool_test.go deleted file mode 100644 index ac7cf86be3ef7..0000000000000 --- a/util/pool/pool_test.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package pool - -import ( - "slices" - "testing" -) - -func TestPool(t *testing.T) { - p := Pool[int]{} - - if got, want := p.Len(), 0; got != want { - t.Errorf("got initial length %v; want %v", got, want) - } - - h1 := p.Add(101) - h2 := p.Add(102) - h3 := p.Add(103) - h4 := p.Add(104) - - if got, want := p.Len(), 4; got != want { - t.Errorf("got length %v; want %v", got, want) - } - - tests := []struct { - h Handle[int] - want int - }{ - {h1, 101}, - {h2, 102}, - {h3, 103}, - {h4, 104}, - } - for i, test := range tests { - got, ok := p.Peek(test.h) - if !ok { - t.Errorf("test[%d]: did not find item", i) - continue - } - if got != test.want { - t.Errorf("test[%d]: got %v; want %v", i, got, test.want) - } - } - - if deleted := p.Delete(h2); !deleted { - t.Errorf("h2 not deleted") - } - if deleted := p.Delete(h2); deleted { - t.Errorf("h2 should not be deleted twice") - } - if got, want := p.Len(), 3; got != want { - t.Errorf("got length %v; want %v", got, want) - } - if _, ok := p.Peek(h2); ok { - t.Errorf("h2 still in pool") - } - - // Remove an item by handle - got, ok := p.Take(h4) - if !ok { - t.Errorf("h4 not found") - } - if got != 104 { - t.Errorf("got %v; want 104", got) - } - - // Take doesn't work on previously-taken or deleted items. - if _, ok := p.Take(h4); ok { - t.Errorf("h4 should not be taken twice") - } - if _, ok := p.Take(h2); ok { - t.Errorf("h2 should not be taken after delete") - } - - // Remove all items and return them - items := p.AppendTakeAll(nil) - want := []int{101, 103} - if !slices.Equal(items, want) { - t.Errorf("got items %v; want %v", items, want) - } - if got := p.Len(); got != 0 { - t.Errorf("got length %v; want 0", got) - } - - // Insert and then clear should result in no items. - p.Add(105) - p.Clear() - if got := p.Len(); got != 0 { - t.Errorf("got length %v; want 0", got) - } -} - -func TestTakeRandom(t *testing.T) { - p := Pool[int]{} - for i := 0; i < 10; i++ { - p.Add(i + 100) - } - - seen := make(map[int]bool) - for i := 0; i < 10; i++ { - item, ok := p.TakeRandom() - if !ok { - t.Errorf("unexpected empty pool") - break - } - if seen[item] { - t.Errorf("got duplicate item %v", item) - } - seen[item] = true - } - - // Verify that the pool is empty - if _, ok := p.TakeRandom(); ok { - t.Errorf("expected empty pool") - } - - for i := 0; i < 10; i++ { - want := 100 + i - if !seen[want] { - t.Errorf("item %v not seen", want) - } - } - - if t.Failed() { - t.Logf("seen: %+v", seen) - } -} - -func BenchmarkPool_AddDelete(b *testing.B) { - b.Run("impl=Pool", func(b *testing.B) { - p := Pool[int]{} - - // Warm up/force an initial allocation - h := p.Add(0) - p.Delete(h) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - h := p.Add(i) - p.Delete(h) - } - }) - b.Run("impl=map", func(b *testing.B) { - p := make(map[int]bool) - - // Force initial allocation - p[0] = true - delete(p, 0) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - p[i] = true - delete(p, i) - } - }) -} - -func BenchmarkPool_TakeRandom(b *testing.B) { - b.Run("impl=Pool", func(b *testing.B) { - p := Pool[int]{} - - // Insert the number of items we'll be taking, then reset the timer. - for i := 0; i < b.N; i++ { - p.Add(i) - } - b.ResetTimer() - - // Now benchmark taking all the items. - for i := 0; i < b.N; i++ { - p.TakeRandom() - } - - if p.Len() != 0 { - b.Errorf("pool not empty") - } - }) - b.Run("impl=map", func(b *testing.B) { - p := make(map[int]bool) - - // Insert the number of items we'll be taking, then reset the timer. - for i := 0; i < b.N; i++ { - p[i] = true - } - b.ResetTimer() - - // Now benchmark taking all the items. - for i := 0; i < b.N; i++ { - // Taking a random item is simulated by a single map iteration. - for k := range p { - delete(p, k) // "take" the item by removing it - break - } - } - - if len(p) != 0 { - b.Errorf("map not empty") - } - }) -} diff --git a/util/set/intset.go b/util/set/intset.go index 04f614742e796..29a634516a510 100644 --- a/util/set/intset.go +++ b/util/set/intset.go @@ -152,7 +152,7 @@ func (s bitSet) values() iter.Seq[uint64] { return func(yield func(uint64) bool) { // Hyrum-proofing: randomly iterate in forwards or reverse. if rand.Uint64()%2 == 0 { - for i := 0; i < bits.UintSize; i++ { + for i := range bits.UintSize { if s.contains(uint64(i)) && !yield(uint64(i)) { return } diff --git a/util/singleflight/singleflight.go b/util/singleflight/singleflight.go index 23cf7e21fec15..e6d859178140b 100644 --- a/util/singleflight/singleflight.go +++ b/util/singleflight/singleflight.go @@ -36,7 +36,7 @@ var errGoexit = errors.New("runtime.Goexit was called") // A panicError is an arbitrary value recovered from a panic // with the stack trace during the execution of given function. type panicError struct { - value interface{} + value any stack []byte } @@ -45,7 +45,7 @@ func (p *panicError) Error() string { return fmt.Sprintf("%v\n\n%s", p.value, p.stack) } -func newPanicError(v interface{}) error { +func newPanicError(v any) error { stack := debug.Stack() // The first line of the stack trace is of the form "goroutine N [status]:" diff --git a/util/singleflight/singleflight_test.go b/util/singleflight/singleflight_test.go index 9f0ca7f1de853..4e8500cc3c3d6 100644 --- a/util/singleflight/singleflight_test.go +++ b/util/singleflight/singleflight_test.go @@ -25,7 +25,7 @@ import ( func TestDo(t *testing.T) { var g Group[string, any] - v, err, _ := g.Do("key", func() (interface{}, error) { + v, err, _ := g.Do("key", func() (any, error) { return "bar", nil }) if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { @@ -39,7 +39,7 @@ func TestDo(t *testing.T) { func TestDoErr(t *testing.T) { var g Group[string, any] someErr := errors.New("Some error") - v, err, _ := g.Do("key", func() (interface{}, error) { + v, err, _ := g.Do("key", func() (any, error) { return nil, someErr }) if err != someErr { @@ -55,7 +55,7 @@ func TestDoDupSuppress(t *testing.T) { var wg1, wg2 sync.WaitGroup c := make(chan string, 1) var calls int32 - fn := func() (interface{}, error) { + fn := func() (any, error) { if atomic.AddInt32(&calls, 1) == 1 { // First invocation. wg1.Done() @@ -72,9 +72,7 @@ func TestDoDupSuppress(t *testing.T) { wg1.Add(1) for range n { wg1.Add(1) - wg2.Add(1) - go func() { - defer wg2.Done() + wg2.Go(func() { wg1.Done() v, err, _ := g.Do("key", fn) if err != nil { @@ -84,7 +82,7 @@ func TestDoDupSuppress(t *testing.T) { if s, _ := v.(string); s != "bar" { t.Errorf("Do = %T %v; want %q", v, v, "bar") } - }() + }) } wg1.Wait() // At least one goroutine is in fn now and all of them have at @@ -108,7 +106,7 @@ func TestForget(t *testing.T) { ) go func() { - g.Do("key", func() (i interface{}, e error) { + g.Do("key", func() (i any, e error) { close(firstStarted) <-unblockFirst close(firstFinished) @@ -119,7 +117,7 @@ func TestForget(t *testing.T) { g.Forget("key") unblockSecond := make(chan struct{}) - secondResult := g.DoChan("key", func() (i interface{}, e error) { + secondResult := g.DoChan("key", func() (i any, e error) { <-unblockSecond return 2, nil }) @@ -127,7 +125,7 @@ func TestForget(t *testing.T) { close(unblockFirst) <-firstFinished - thirdResult := g.DoChan("key", func() (i interface{}, e error) { + thirdResult := g.DoChan("key", func() (i any, e error) { return 3, nil }) @@ -141,7 +139,7 @@ func TestForget(t *testing.T) { func TestDoChan(t *testing.T) { var g Group[string, any] - ch := g.DoChan("key", func() (interface{}, error) { + ch := g.DoChan("key", func() (any, error) { return "bar", nil }) @@ -160,7 +158,7 @@ func TestDoChan(t *testing.T) { // See https://github.com/golang/go/issues/41133 func TestPanicDo(t *testing.T) { var g Group[string, any] - fn := func() (interface{}, error) { + fn := func() (any, error) { panic("invalid memory address or nil pointer dereference") } @@ -197,7 +195,7 @@ func TestPanicDo(t *testing.T) { func TestGoexitDo(t *testing.T) { var g Group[string, any] - fn := func() (interface{}, error) { + fn := func() (any, error) { runtime.Goexit() return nil, nil } @@ -238,7 +236,7 @@ func TestPanicDoChan(t *testing.T) { }() g := new(Group[string, any]) - ch := g.DoChan("", func() (interface{}, error) { + ch := g.DoChan("", func() (any, error) { panic("Panicking in DoChan") }) <-ch @@ -283,7 +281,7 @@ func TestPanicDoSharedByDoChan(t *testing.T) { defer func() { recover() }() - g.Do("", func() (interface{}, error) { + g.Do("", func() (any, error) { close(blocked) <-unblock panic("Panicking in Do") @@ -291,7 +289,7 @@ func TestPanicDoSharedByDoChan(t *testing.T) { }() <-blocked - ch := g.DoChan("", func() (interface{}, error) { + ch := g.DoChan("", func() (any, error) { panic("DoChan unexpectedly executed callback") }) close(unblock) @@ -325,8 +323,7 @@ func TestPanicDoSharedByDoChan(t *testing.T) { func TestDoChanContext(t *testing.T) { t.Run("Basic", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() var g Group[string, int] ch := g.DoChanContext(ctx, "key", func(_ context.Context) (int, error) { @@ -337,8 +334,7 @@ func TestDoChanContext(t *testing.T) { }) t.Run("DoesNotPropagateValues", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() key := new(int) const value = "hello world" @@ -364,8 +360,7 @@ func TestDoChanContext(t *testing.T) { ctx1, cancel1 := context.WithCancel(context.Background()) defer cancel1() - ctx2, cancel2 := context.WithCancel(context.Background()) - defer cancel2() + ctx2 := t.Context() fn := func(ctx context.Context) (int, error) { select { diff --git a/util/slicesx/slicesx_test.go b/util/slicesx/slicesx_test.go index d5c87a3727748..6b28c29b47382 100644 --- a/util/slicesx/slicesx_test.go +++ b/util/slicesx/slicesx_test.go @@ -53,7 +53,7 @@ func TestShuffle(t *testing.T) { } var wasShuffled bool - for try := 0; try < 10; try++ { + for range 10 { shuffled := slices.Clone(sl) Shuffle(shuffled) if !reflect.DeepEqual(shuffled, sl) { diff --git a/util/syspolicy/policytest/policytest.go b/util/syspolicy/policytest/policytest.go index ef5ce889dd2de..9879a0fd3c69c 100644 --- a/util/syspolicy/policytest/policytest.go +++ b/util/syspolicy/policytest/policytest.go @@ -89,12 +89,7 @@ func (pc policyChanges) HasChanged(v pkey.Key) bool { return ok } func (pc policyChanges) HasChangedAnyOf(keys ...pkey.Key) bool { - for _, k := range keys { - if pc.HasChanged(k) { - return true - } - } - return false + return slices.ContainsFunc(keys, pc.HasChanged) } const watchersKey = "_policytest_watchers" diff --git a/util/syspolicy/setting/errors.go b/util/syspolicy/setting/errors.go index 655018d4b5aff..c8e0d8121ec2a 100644 --- a/util/syspolicy/setting/errors.go +++ b/util/syspolicy/setting/errors.go @@ -5,8 +5,6 @@ package setting import ( "errors" - - "tailscale.com/types/ptr" ) var ( @@ -39,7 +37,7 @@ type ErrorText string // NewErrorText returns a [ErrorText] with the specified error message. func NewErrorText(text string) *ErrorText { - return ptr.To(ErrorText(text)) + return new(ErrorText(text)) } // MaybeErrorText returns an [ErrorText] with the text of the specified error, @@ -51,7 +49,7 @@ func MaybeErrorText(err error) *ErrorText { if err, ok := err.(*ErrorText); ok { return err } - return ptr.To(ErrorText(err.Error())) + return new(ErrorText(err.Error())) } // Error implements error. diff --git a/util/syspolicy/setting/policy_scope_test.go b/util/syspolicy/setting/policy_scope_test.go index a2f6328151d05..9cdbbe7ab5df4 100644 --- a/util/syspolicy/setting/policy_scope_test.go +++ b/util/syspolicy/setting/policy_scope_test.go @@ -226,42 +226,42 @@ func TestPolicyScopeContains(t *testing.T) { wantAStrictlyContainsB: false, }, { - name: "UserScope(1234)/UserScope(1234)", + name: "UserScope-1234/UserScope-1234", scopeA: UserScopeOf("1234"), scopeB: UserScopeOf("1234"), wantAContainsB: true, wantAStrictlyContainsB: false, }, { - name: "UserScope(1234)/UserScope(5678)", + name: "UserScope-1234/UserScope-5678", scopeA: UserScopeOf("1234"), scopeB: UserScopeOf("5678"), wantAContainsB: false, wantAStrictlyContainsB: false, }, { - name: "ProfileScope(A)/UserScope(A/1234)", + name: "ProfileScope-A/UserScope-A-1234", scopeA: PolicyScope{kind: ProfileSetting, profileID: "A"}, scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "A"}, wantAContainsB: true, wantAStrictlyContainsB: true, }, { - name: "ProfileScope(A)/UserScope(B/1234)", + name: "ProfileScope-A/UserScope-B-1234", scopeA: PolicyScope{kind: ProfileSetting, profileID: "A"}, scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "B"}, wantAContainsB: false, wantAStrictlyContainsB: false, }, { - name: "UserScope(1234)/UserScope(A/1234)", + name: "UserScope-1234/UserScope-A-1234", scopeA: PolicyScope{kind: UserSetting, userID: "1234"}, scopeB: PolicyScope{kind: UserSetting, userID: "1234", profileID: "A"}, wantAContainsB: true, wantAStrictlyContainsB: true, }, { - name: "UserScope(1234)/UserScope(A/5678)", + name: "UserScope-1234/UserScope-A-5678", scopeA: PolicyScope{kind: UserSetting, userID: "1234"}, scopeB: PolicyScope{kind: UserSetting, userID: "5678", profileID: "A"}, wantAContainsB: false, diff --git a/util/syspolicy/setting/setting_test.go b/util/syspolicy/setting/setting_test.go index 3ccd2ef606c50..885491b679ba8 100644 --- a/util/syspolicy/setting/setting_test.go +++ b/util/syspolicy/setting/setting_test.go @@ -9,7 +9,6 @@ import ( "testing" "tailscale.com/types/lazy" - "tailscale.com/types/ptr" "tailscale.com/util/syspolicy/internal" "tailscale.com/util/syspolicy/pkey" ) @@ -138,7 +137,7 @@ func TestSettingDefinition(t *testing.T) { if !tt.setting.Equal(tt.setting) { t.Errorf("the setting should be equal to itself") } - if tt.setting != nil && !tt.setting.Equal(ptr.To(*tt.setting)) { + if tt.setting != nil && !tt.setting.Equal(new(*tt.setting)) { t.Errorf("the setting should be equal to its shallow copy") } if gotKey := tt.setting.Key(); gotKey != tt.wantKey { diff --git a/util/syspolicy/syspolicy_test.go b/util/syspolicy/syspolicy_test.go index 532cd03b8b9a7..c62c90dddfb33 100644 --- a/util/syspolicy/syspolicy_test.go +++ b/util/syspolicy/syspolicy_test.go @@ -44,7 +44,7 @@ func TestGetString(t *testing.T) { wantMetrics []metrics.TestState }{ { - name: "read existing value", + name: "read-existing-value", key: pkey.AdminConsoleVisibility, handlerValue: "hide", wantValue: "hide", @@ -54,13 +54,13 @@ func TestGetString(t *testing.T) { }, }, { - name: "read non-existing value", + name: "read-non-existing-value", key: pkey.EnableServerMode, handlerError: ErrNotConfigured, wantError: nil, }, { - name: "read non-existing value, non-blank default", + name: "read-non-existing-value-non-blank-default", key: pkey.EnableServerMode, handlerError: ErrNotConfigured, defaultValue: "test", @@ -68,7 +68,7 @@ func TestGetString(t *testing.T) { wantError: nil, }, { - name: "reading value returns other error", + name: "reading-value-returns-other-error", key: pkey.NetworkDevicesVisibility, handlerError: someOtherError, wantError: someOtherError, @@ -124,27 +124,27 @@ func TestGetUint64(t *testing.T) { wantError error }{ { - name: "read existing value", + name: "read-existing-value", key: pkey.LogSCMInteractions, handlerValue: 1, wantValue: 1, }, { - name: "read non-existing value", + name: "read-non-existing-value", key: pkey.LogSCMInteractions, handlerValue: 0, handlerError: ErrNotConfigured, wantValue: 0, }, { - name: "read non-existing value, non-zero default", + name: "read-non-existing-value-non-zero-default", key: pkey.LogSCMInteractions, defaultValue: 2, handlerError: ErrNotConfigured, wantValue: 2, }, { - name: "reading value returns other error", + name: "reading-value-returns-other-error", key: pkey.FlushDNSOnSessionUnlock, handlerError: someOtherError, wantError: someOtherError, @@ -191,7 +191,7 @@ func TestGetBoolean(t *testing.T) { wantMetrics []metrics.TestState }{ { - name: "read existing value", + name: "read-existing-value", key: pkey.FlushDNSOnSessionUnlock, handlerValue: true, wantValue: true, @@ -201,14 +201,14 @@ func TestGetBoolean(t *testing.T) { }, }, { - name: "read non-existing value", + name: "read-non-existing-value", key: pkey.LogSCMInteractions, handlerValue: false, handlerError: ErrNotConfigured, wantValue: false, }, { - name: "reading value returns other error", + name: "reading-value-returns-other-error", key: pkey.FlushDNSOnSessionUnlock, handlerError: someOtherError, wantError: someOtherError, // expect error... @@ -266,7 +266,7 @@ func TestGetPreferenceOption(t *testing.T) { wantMetrics []metrics.TestState }{ { - name: "always by policy", + name: "always-by-policy", key: pkey.EnableIncomingConnections, handlerValue: "always", wantValue: ptype.AlwaysByPolicy, @@ -276,7 +276,7 @@ func TestGetPreferenceOption(t *testing.T) { }, }, { - name: "never by policy", + name: "never-by-policy", key: pkey.EnableIncomingConnections, handlerValue: "never", wantValue: ptype.NeverByPolicy, @@ -286,7 +286,7 @@ func TestGetPreferenceOption(t *testing.T) { }, }, { - name: "use default", + name: "use-default", key: pkey.EnableIncomingConnections, handlerValue: "", wantValue: ptype.ShowChoiceByPolicy, @@ -296,13 +296,13 @@ func TestGetPreferenceOption(t *testing.T) { }, }, { - name: "read non-existing value", + name: "read-non-existing-value", key: pkey.EnableIncomingConnections, handlerError: ErrNotConfigured, wantValue: ptype.ShowChoiceByPolicy, }, { - name: "other error is returned", + name: "other-error-is-returned", key: pkey.EnableIncomingConnections, handlerError: someOtherError, wantValue: ptype.ShowChoiceByPolicy, @@ -359,7 +359,7 @@ func TestGetVisibility(t *testing.T) { wantMetrics []metrics.TestState }{ { - name: "hidden by policy", + name: "hidden-by-policy", key: pkey.AdminConsoleVisibility, handlerValue: "hide", wantValue: ptype.HiddenByPolicy, @@ -369,7 +369,7 @@ func TestGetVisibility(t *testing.T) { }, }, { - name: "visibility default", + name: "visibility-default", key: pkey.AdminConsoleVisibility, handlerValue: "show", wantValue: ptype.VisibleByPolicy, @@ -379,14 +379,14 @@ func TestGetVisibility(t *testing.T) { }, }, { - name: "read non-existing value", + name: "read-non-existing-value", key: pkey.AdminConsoleVisibility, handlerValue: "show", handlerError: ErrNotConfigured, wantValue: ptype.VisibleByPolicy, }, { - name: "other error is returned", + name: "other-error-is-returned", key: pkey.AdminConsoleVisibility, handlerValue: "show", handlerError: someOtherError, @@ -445,7 +445,7 @@ func TestGetDuration(t *testing.T) { wantMetrics []metrics.TestState }{ { - name: "read existing value", + name: "read-existing-value", key: pkey.KeyExpirationNoticeTime, handlerValue: "2h", wantValue: 2 * time.Hour, @@ -456,7 +456,7 @@ func TestGetDuration(t *testing.T) { }, }, { - name: "invalid duration value", + name: "invalid-duration-value", key: pkey.KeyExpirationNoticeTime, handlerValue: "-20", wantValue: 24 * time.Hour, @@ -468,21 +468,21 @@ func TestGetDuration(t *testing.T) { }, }, { - name: "read non-existing value", + name: "read-non-existing-value", key: pkey.KeyExpirationNoticeTime, handlerError: ErrNotConfigured, wantValue: 24 * time.Hour, defaultValue: 24 * time.Hour, }, { - name: "read non-existing value different default", + name: "read-non-existing-value-different-default", key: pkey.KeyExpirationNoticeTime, handlerError: ErrNotConfigured, wantValue: 0 * time.Second, defaultValue: 0 * time.Second, }, { - name: "other error is returned", + name: "other-error-is-returned", key: pkey.KeyExpirationNoticeTime, handlerError: someOtherError, wantValue: 24 * time.Hour, @@ -541,7 +541,7 @@ func TestGetStringArray(t *testing.T) { wantMetrics []metrics.TestState }{ { - name: "read existing value", + name: "read-existing-value", key: pkey.AllowedSuggestedExitNodes, handlerValue: []string{"foo", "bar"}, wantValue: []string{"foo", "bar"}, @@ -551,13 +551,13 @@ func TestGetStringArray(t *testing.T) { }, }, { - name: "read non-existing value", + name: "read-non-existing-value", key: pkey.AllowedSuggestedExitNodes, handlerError: ErrNotConfigured, wantError: nil, }, { - name: "read non-existing value, non nil default", + name: "read-non-existing-value-non-nil-default", key: pkey.AllowedSuggestedExitNodes, handlerError: ErrNotConfigured, defaultValue: []string{"foo", "bar"}, @@ -565,7 +565,7 @@ func TestGetStringArray(t *testing.T) { wantError: nil, }, { - name: "reading value returns other error", + name: "reading-value-returns-other-error", key: pkey.AllowedSuggestedExitNodes, handlerError: someOtherError, wantError: someOtherError, diff --git a/util/sysresources/memory.go b/util/sysresources/memory.go deleted file mode 100644 index 3c6b9ae852e47..0000000000000 --- a/util/sysresources/memory.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package sysresources - -// TotalMemory returns the total accessible system memory, in bytes. If the -// value cannot be determined, then 0 will be returned. -func TotalMemory() uint64 { - return totalMemoryImpl() -} diff --git a/util/sysresources/memory_bsd.go b/util/sysresources/memory_bsd.go deleted file mode 100644 index 945f86ea35ec9..0000000000000 --- a/util/sysresources/memory_bsd.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build freebsd || openbsd || dragonfly || netbsd - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - val, err := unix.SysctlUint64("hw.physmem") - if err != nil { - return 0 - } - return val -} diff --git a/util/sysresources/memory_darwin.go b/util/sysresources/memory_darwin.go deleted file mode 100644 index 165f12eb3b808..0000000000000 --- a/util/sysresources/memory_darwin.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - val, err := unix.SysctlUint64("hw.memsize") - if err != nil { - return 0 - } - return val -} diff --git a/util/sysresources/memory_linux.go b/util/sysresources/memory_linux.go deleted file mode 100644 index 3885a8aa6c66e..0000000000000 --- a/util/sysresources/memory_linux.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - var info unix.Sysinfo_t - - if err := unix.Sysinfo(&info); err != nil { - return 0 - } - - // uint64 casts are required since these might be uint32s - return uint64(info.Totalram) * uint64(info.Unit) -} diff --git a/util/sysresources/memory_unsupported.go b/util/sysresources/memory_unsupported.go deleted file mode 100644 index c88e9ed5201e9..0000000000000 --- a/util/sysresources/memory_unsupported.go +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !(linux || darwin || freebsd || openbsd || dragonfly || netbsd) - -package sysresources - -func totalMemoryImpl() uint64 { return 0 } diff --git a/util/sysresources/sysresources.go b/util/sysresources/sysresources.go deleted file mode 100644 index 33d0d5d96a96e..0000000000000 --- a/util/sysresources/sysresources.go +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package sysresources provides OS-independent methods of determining the -// resources available to the current system. -package sysresources diff --git a/util/sysresources/sysresources_test.go b/util/sysresources/sysresources_test.go deleted file mode 100644 index 7fea1bf0f5b32..0000000000000 --- a/util/sysresources/sysresources_test.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package sysresources - -import ( - "runtime" - "testing" -) - -func TestTotalMemory(t *testing.T) { - switch runtime.GOOS { - case "linux": - case "freebsd", "openbsd", "dragonfly", "netbsd": - case "darwin": - default: - t.Skipf("not supported on runtime.GOOS=%q yet", runtime.GOOS) - } - - mem := TotalMemory() - if mem == 0 { - t.Fatal("wanted TotalMemory > 0") - } - t.Logf("total memory: %v bytes", mem) -} diff --git a/util/topk/topk.go b/util/topk/topk.go deleted file mode 100644 index 95ebd895d05aa..0000000000000 --- a/util/topk/topk.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package topk defines a count-min sketch and a cheap probabilistic top-K data -// structure that uses the count-min sketch to track the top K items in -// constant memory and O(log(k)) time. -package topk - -import ( - "container/heap" - "hash/maphash" - "math" - "slices" - "sync" -) - -// TopK is a probabilistic counter of the top K items, using a count-min sketch -// to keep track of item counts and a heap to track the top K of them. -type TopK[T any] struct { - heap minHeap[T] - k int - sf SerializeFunc[T] - cms CountMinSketch -} - -// HashFunc is responsible for providing a []byte serialization of a value, -// appended to the provided byte slice. This is used for hashing the value when -// adding to a CountMinSketch. -type SerializeFunc[T any] func([]byte, T) []byte - -// New creates a new TopK that stores k values. Parameters for the underlying -// count-min sketch are chosen for a 0.1% error rate and a 0.1% probability of -// error. -func New[T any](k int, sf SerializeFunc[T]) *TopK[T] { - hashes, buckets := PickParams(0.001, 0.001) - return NewWithParams(k, sf, hashes, buckets) -} - -// NewWithParams creates a new TopK that stores k values, and additionally -// allows customizing the parameters for the underlying count-min sketch. -func NewWithParams[T any](k int, sf SerializeFunc[T], numHashes, numCols int) *TopK[T] { - ret := &TopK[T]{ - heap: make(minHeap[T], 0, k), - k: k, - sf: sf, - } - ret.cms.init(numHashes, numCols) - return ret -} - -// Add calls AddN(val, 1). -func (tk *TopK[T]) Add(val T) uint64 { - return tk.AddN(val, 1) -} - -var hashPool = &sync.Pool{ - New: func() any { - buf := make([]byte, 0, 128) - return &buf - }, -} - -// AddN adds the given item to the set with the provided count, returning the -// new estimated count. -func (tk *TopK[T]) AddN(val T, count uint64) uint64 { - buf := hashPool.Get().(*[]byte) - defer hashPool.Put(buf) - ser := tk.sf((*buf)[:0], val) - - vcount := tk.cms.AddN(ser, count) - - // If we don't have a full heap, just push it. - if len(tk.heap) < tk.k { - heap.Push(&tk.heap, mhValue[T]{ - count: vcount, - val: val, - }) - return vcount - } - - // If this item's count surpasses the heap's minimum, update the heap. - if vcount > tk.heap[0].count { - tk.heap[0] = mhValue[T]{ - count: vcount, - val: val, - } - heap.Fix(&tk.heap, 0) - } - return vcount -} - -// Top returns the estimated top K items as stored by this TopK. -func (tk *TopK[T]) Top() []T { - ret := make([]T, 0, tk.k) - for _, item := range tk.heap { - ret = append(ret, item.val) - } - return ret -} - -// AppendTop appends the estimated top K items as stored by this TopK to the -// provided slice, allocating only if the slice does not have enough capacity -// to store all items. The provided slice can be nil. -func (tk *TopK[T]) AppendTop(sl []T) []T { - sl = slices.Grow(sl, tk.k) - for _, item := range tk.heap { - sl = append(sl, item.val) - } - return sl -} - -// CountMinSketch implements a count-min sketch, a probabilistic data structure -// that tracks the frequency of events in a stream of data. -// -// See: https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch -type CountMinSketch struct { - hashes []maphash.Seed - nbuckets int - matrix []uint64 -} - -// NewCountMinSketch creates a new CountMinSketch with the provided number of -// hashes and buckets. Hashes and buckets are often called "depth" and "width", -// or "d" and "w", respectively. -func NewCountMinSketch(hashes, buckets int) *CountMinSketch { - ret := &CountMinSketch{} - ret.init(hashes, buckets) - return ret -} - -// PickParams provides good parameters for 'hashes' and 'buckets' when -// constructing a CountMinSketch, given an estimated total number of counts -// (i.e. the sum of all counts ever stored), the error factor Īĩ as a float -// (e.g. 1% is 0.001), and the probability factor δ. -// -// Parameters are chosen such that with a probability of 1−δ, the error is at -// most Īĩ∗totalCount. Or, in other words: if N is the true count of an event, -// E is the estimate given by a sketch and T the total count of items in the -// sketch, E ≤ N + T*Īĩ with probability (1 - δ). -func PickParams(err, probability float64) (hashes, buckets int) { - d := math.Ceil(math.Log(1 / probability)) - w := math.Ceil(math.E / err) - - return int(d), int(w) -} - -func (cms *CountMinSketch) init(hashes, buckets int) { - for range hashes { - cms.hashes = append(cms.hashes, maphash.MakeSeed()) - } - - // Need a matrix of hashes * buckets to store counts - cms.nbuckets = buckets - cms.matrix = make([]uint64, hashes*buckets) -} - -// Add calls AddN(val, 1). -func (cms *CountMinSketch) Add(val []byte) uint64 { - return cms.AddN(val, 1) -} - -// AddN increments the count for the given value by the provided count, -// returning the new count. -func (cms *CountMinSketch) AddN(val []byte, count uint64) uint64 { - var ( - mh maphash.Hash - ret uint64 = math.MaxUint64 - ) - for i, seed := range cms.hashes { - mh.SetSeed(seed) - - // Generate a hash for this value using Lemire's alternative to modular reduction: - // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ - mh.Write(val) - hash := mh.Sum64() - hash = multiplyHigh64(hash, uint64(cms.nbuckets)) - - // The index in our matrix is (i * buckets) to move "down" i - // rows in our matrix to the row for this hash, plus 'hash' to - // move inside this row. - idx := (i * cms.nbuckets) + int(hash) - - // Add to this row - cms.matrix[idx] += count - ret = min(ret, cms.matrix[idx]) - } - return ret -} - -// Get returns the count for the provided value. -func (cms *CountMinSketch) Get(val []byte) uint64 { - var ( - mh maphash.Hash - ret uint64 = math.MaxUint64 - ) - for i, seed := range cms.hashes { - mh.SetSeed(seed) - - // Generate a hash for this value using Lemire's alternative to modular reduction: - // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ - mh.Write(val) - hash := mh.Sum64() - hash = multiplyHigh64(hash, uint64(cms.nbuckets)) - - // The index in our matrix is (i * buckets) to move "down" i - // rows in our matrix to the row for this hash, plus 'hash' to - // move inside this row. - idx := (i * cms.nbuckets) + int(hash) - - // Select the minimal value among all rows - ret = min(ret, cms.matrix[idx]) - } - return ret -} - -// multiplyHigh64 implements (x * y) >> 64 "the long way" without access to a -// 128-bit type. This function is adapted from something similar in Tensorflow: -// -// https://github.com/tensorflow/tensorflow/commit/a47a300185026fe7829990def9113bf3a5109fed -// -// TODO(andrew-d): this could be replaced with a single "MULX" instruction on -// x86_64 platforms, which we can do if this ever turns out to be a performance -// bottleneck. -func multiplyHigh64(x, y uint64) uint64 { - x_lo := x & 0xffffffff - x_hi := x >> 32 - buckets_lo := y & 0xffffffff - buckets_hi := y >> 32 - prod_hi := x_hi * buckets_hi - prod_lo := x_lo * buckets_lo - prod_mid1 := x_hi * buckets_lo - prod_mid2 := x_lo * buckets_hi - carry := ((prod_mid1 & 0xffffffff) + (prod_mid2 & 0xffffffff) + (prod_lo >> 32)) >> 32 - return prod_hi + (prod_mid1 >> 32) + (prod_mid2 >> 32) + carry -} - -type mhValue[T any] struct { - count uint64 - val T -} - -// An minHeap is a min-heap of ints and associated values. -type minHeap[T any] []mhValue[T] - -func (h minHeap[T]) Len() int { return len(h) } -func (h minHeap[T]) Less(i, j int) bool { return h[i].count < h[j].count } -func (h minHeap[T]) Swap(i, j int) { h[i], h[j] = h[j], h[i] } - -func (h *minHeap[T]) Push(x any) { - // Push and Pop use pointer receivers because they modify the slice's length, - // not just its contents. - *h = append(*h, x.(mhValue[T])) -} - -func (h *minHeap[T]) Pop() any { - old := *h - n := len(old) - x := old[n-1] - *h = old[0 : n-1] - return x -} diff --git a/util/topk/topk_test.go b/util/topk/topk_test.go deleted file mode 100644 index 06656c4204fe6..0000000000000 --- a/util/topk/topk_test.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package topk - -import ( - "encoding/binary" - "fmt" - "slices" - "testing" -) - -func TestCountMinSketch(t *testing.T) { - cms := NewCountMinSketch(4, 10) - items := []string{"foo", "bar", "baz", "asdf", "quux"} - for _, item := range items { - cms.Add([]byte(item)) - } - for _, item := range items { - count := cms.Get([]byte(item)) - if count < 1 { - t.Errorf("item %q should have count >= 1", item) - } else if count > 1 { - t.Logf("item %q has count > 1: %d", item, count) - } - } - - // Test that an item that's *not* in the set has a value lower than the - // total number of items we inserted (in the case that all items - // collided). - noItemCount := cms.Get([]byte("doesn't exist")) - if noItemCount > uint64(len(items)) { - t.Errorf("expected nonexistent item to have value < %d; got %d", len(items), noItemCount) - } -} - -func TestTopK(t *testing.T) { - // This is probabilistic, so we're going to try 10 times to get the - // "right" value; the likelihood that we fail on all attempts is - // vanishingly small since the number of hash buckets is drastically - // larger than the number of items we're inserting. - var ( - got []int - want = []int{5, 6, 7, 8, 9} - ) - for try := 0; try < 10; try++ { - topk := NewWithParams[int](5, func(in []byte, val int) []byte { - return binary.LittleEndian.AppendUint64(in, uint64(val)) - }, 4, 1000) - - // Add the first 10 integers with counts equal to 2x their value - for i := range 10 { - topk.AddN(i, uint64(i*2)) - } - - got = topk.Top() - t.Logf("top K items: %+v", got) - slices.Sort(got) - - if slices.Equal(got, want) { - // All good! - return - } - - // continue and retry or fail - } - - t.Errorf("top K mismatch\ngot: %v\nwant: %v", got, want) -} - -func TestPickParams(t *testing.T) { - hashes, buckets := PickParams( - 0.001, // 0.1% error rate - 0.001, // 0.1% chance of having an error, or 99.9% chance of not having an error - ) - t.Logf("hashes = %d, buckets = %d", hashes, buckets) -} - -func BenchmarkCountMinSketch(b *testing.B) { - cms := NewCountMinSketch(PickParams(0.001, 0.001)) - b.ResetTimer() - b.ReportAllocs() - - var enc [8]byte - for i := range b.N { - binary.LittleEndian.PutUint64(enc[:], uint64(i)) - cms.Add(enc[:]) - } -} - -func BenchmarkTopK(b *testing.B) { - for _, n := range []int{ - 10, - 128, - 256, - 1024, - 8192, - } { - b.Run(fmt.Sprintf("Top%d", n), func(b *testing.B) { - out := make([]int, 0, n) - topk := New[int](n, func(in []byte, val int) []byte { - return binary.LittleEndian.AppendUint64(in, uint64(val)) - }) - b.ResetTimer() - b.ReportAllocs() - - for i := range b.N { - topk.Add(i) - } - out = topk.AppendTop(out[:0]) // should not allocate - _ = out // appease linter - }) - } -} - -func TestMultiplyHigh64(t *testing.T) { - testCases := []struct { - x, y uint64 - want uint64 - }{ - {0, 0, 0}, - {0xffffffff, 0xffffffff, 0}, - {0x2, 0xf000000000000000, 1}, - {0x3, 0xf000000000000000, 2}, - {0x3, 0xf000000000000001, 2}, - {0x3, 0xffffffffffffffff, 2}, - {0xffffffffffffffff, 0xffffffffffffffff, 0xfffffffffffffffe}, - } - for _, tc := range testCases { - got := multiplyHigh64(tc.x, tc.y) - if got != tc.want { - t.Errorf("got multiplyHigh64(%x, %x) = %x, want %x", tc.x, tc.y, got, tc.want) - } - } -} diff --git a/util/vizerror/vizerror.go b/util/vizerror/vizerror.go index 479bd2de9e7c8..e0abe8f97d15e 100644 --- a/util/vizerror/vizerror.go +++ b/util/vizerror/vizerror.go @@ -77,6 +77,5 @@ func WrapWithMessage(wrapped error, publicMsg string) error { // As returns the first vizerror.Error in err's chain. func As(err error) (e Error, ok bool) { - ok = errors.As(err, &e) - return + return errors.AsType[Error](err) } diff --git a/util/zstdframe/zstd_test.go b/util/zstdframe/zstd_test.go index 302090b9951b8..c006a06fd9d39 100644 --- a/util/zstdframe/zstd_test.go +++ b/util/zstdframe/zstd_test.go @@ -128,7 +128,7 @@ func BenchmarkEncode(b *testing.B) { b.Run(bb.name, func(b *testing.B) { b.ReportAllocs() b.SetBytes(int64(len(src))) - for range b.N { + for b.Loop() { dst = AppendEncode(dst[:0], src, bb.opts...) } }) @@ -153,7 +153,7 @@ func BenchmarkDecode(b *testing.B) { b.Run(bb.name, func(b *testing.B) { b.ReportAllocs() b.SetBytes(int64(len(src))) - for range b.N { + for b.Loop() { dst = must.Get(AppendDecode(dst[:0], src, bb.opts...)) } }) @@ -169,16 +169,14 @@ func BenchmarkEncodeParallel(b *testing.B) { } b.Run(coder.name, func(b *testing.B) { b.ReportAllocs() - for range b.N { - var group sync.WaitGroup - for j := 0; j < numCPU; j++ { - group.Add(1) - go func(j int) { - defer group.Done() + for b.Loop() { + var wg sync.WaitGroup + for j := range numCPU { + wg.Go(func() { dsts[j] = coder.appendEncode(dsts[j][:0], src) - }(j) + }) } - group.Wait() + wg.Wait() } }) } @@ -194,16 +192,14 @@ func BenchmarkDecodeParallel(b *testing.B) { } b.Run(coder.name, func(b *testing.B) { b.ReportAllocs() - for range b.N { - var group sync.WaitGroup - for j := 0; j < numCPU; j++ { - group.Add(1) - go func(j int) { - defer group.Done() + for b.Loop() { + var wg sync.WaitGroup + for j := range numCPU { + wg.Go(func() { dsts[j] = must.Get(coder.appendDecode(dsts[j][:0], src)) - }(j) + }) } - group.Wait() + wg.Wait() } }) } diff --git a/version/cmdname.go b/version/cmdname.go index 8a4040f9718b9..8e6adb047c8cc 100644 --- a/version/cmdname.go +++ b/version/cmdname.go @@ -13,6 +13,7 @@ import ( "os" "path" "runtime" + "runtime/debug" "strings" ) @@ -20,6 +21,15 @@ import ( // using os.Executable. If os.Executable fails (it shouldn't), then // "cmd" is returned. func CmdName() string { + // On non-Windows, the modinfo embedded in the running binary is + // authoritative and avoids re-reading the executable from disk. + // Windows needs the executable-name-based GUI override in cmdName, + // so it still takes the slower path. + if runtime.GOOS != "windows" { + if info, ok := debug.ReadBuildInfo(); ok && info.Path != "" { + return path.Base(info.Path) + } + } e, err := os.Executable() if err != nil { return "cmd" @@ -39,7 +49,7 @@ func cmdName(exe string) string { } // v is like: // "path\ttailscale.com/cmd/tailscale\nmod\ttailscale.com\t(devel)\t\ndep\tgithub.com/apenwarr/fixconsole\tv0.0.0-20191012055117-5a9f6489cc29\th1:muXWUcay7DDy1/hEQWrYlBy+g0EuwT70sBHg65SeUc4=\ndep\tgithub.... - for _, line := range strings.Split(info, "\n") { + for line := range strings.SplitSeq(info, "\n") { if goPkg, ok := strings.CutPrefix(line, "path\t"); ok { // like "tailscale.com/cmd/tailscale" ret = path.Base(goPkg) // goPkg is always forward slashes; use path, not filepath break diff --git a/version/cmp.go b/version/cmp.go index 4af0aec69ea6e..6d44475e75743 100644 --- a/version/cmp.go +++ b/version/cmp.go @@ -103,6 +103,11 @@ func parse(version string) (parsed, bool) { } } + // Ignore trailer like '_1 (Void Linux)'. + if rest[0] == '_' && strings.HasSuffix(rest, " (Void Linux)") { + return ret, true + } + // Optional extraCommits, if the next bit can be completely // consumed as an integer. if rest[0] != '-' { diff --git a/version/cmp_test.go b/version/cmp_test.go index 10fc130b768eb..c93df1a7cebe2 100644 --- a/version/cmp_test.go +++ b/version/cmp_test.go @@ -33,6 +33,8 @@ func TestParse(t *testing.T) { {"borkbork", parsed{}, false}, {"1a.2.3", parsed{}, false}, {"", parsed{}, false}, + {"1.96.2_1 (Void Linux)", parsed{Major: 1, Minor: 96, Patch: 2}, true}, + {"1.46.0_2 (Void Linux)", parsed{Major: 1, Minor: 46, Patch: 0}, true}, } for _, test := range tests { @@ -71,6 +73,7 @@ func TestAtLeast(t *testing.T) { {"date.20200612", "date.20200612", true}, {"date.20200701", "date.20200612", true}, {"date.20200501", "date.20200612", false}, + {"1.96.2_1 (Void Linux)", "1.42", true}, } for _, test := range tests { diff --git a/version/print.go b/version/print.go index ca62226ee2b6d..3b4a256cf8781 100644 --- a/version/print.go +++ b/version/print.go @@ -24,7 +24,14 @@ var stringLazy = sync.OnceValue(func() string { if extraGitCommitStamp != "" { fmt.Fprintf(&ret, " other commit: %s\n", extraGitCommitStamp) } - fmt.Fprintf(&ret, " go version: %s\n", runtime.Version()) + if tsGoRev := tailscaleToolchainRev(); tsGoRev != "" { + if len(tsGoRev) > 10 { + tsGoRev = tsGoRev[:10] + } + fmt.Fprintf(&ret, " go version: %s (tailscale/go %s)\n", runtime.Version(), tsGoRev) + } else { + fmt.Fprintf(&ret, " go version: %s\n", runtime.Version()) + } return strings.TrimSpace(ret.String()) }) diff --git a/version/prop.go b/version/prop.go index 36d7699176f1e..59ca74086bd01 100644 --- a/version/prop.go +++ b/version/prop.go @@ -312,6 +312,11 @@ type Meta struct { // GitCommitTime is the commit time of the git commit in GitCommit. GitCommitTime string `json:"gitCommitTime,omitempty"` + // TailscaleGoGitHash is the git commit hash from + // https://github.com/tailscale/go used to build this binary, if built + // with the Tailscale Go toolchain. Otherwise it is empty. + TailscaleGoGitHash string `json:"tailscaleGoGitHash,omitempty"` + // Cap is the current Tailscale capability version. It's a monotonically // incrementing integer that's incremented whenever a new capability is // added. @@ -324,17 +329,18 @@ var getMeta lazy.SyncValue[Meta] func GetMeta() Meta { return getMeta.Get(func() Meta { return Meta{ - MajorMinorPatch: majorMinorPatch(), - Short: Short(), - Long: Long(), - GitCommitTime: getEmbeddedInfo().commitTime, - GitCommit: gitCommit(), - GitDirty: gitDirty(), - OSVariant: osVariant(), - ExtraGitCommit: extraGitCommitStamp, - IsDev: isDev(), - UnstableBranch: IsUnstableBuild(), - Cap: int(tailcfg.CurrentCapabilityVersion), + MajorMinorPatch: majorMinorPatch(), + Short: Short(), + Long: Long(), + GitCommitTime: getEmbeddedInfo().commitTime, + GitCommit: gitCommit(), + GitDirty: gitDirty(), + OSVariant: osVariant(), + ExtraGitCommit: extraGitCommitStamp, + IsDev: isDev(), + UnstableBranch: IsUnstableBuild(), + TailscaleGoGitHash: tailscaleToolchainRev(), + Cap: int(tailcfg.CurrentCapabilityVersion), } }) } diff --git a/version/version.go b/version/version.go index 1171ed2ffe722..8ffc218321357 100644 --- a/version/version.go +++ b/version/version.go @@ -146,6 +146,23 @@ var getEmbeddedInfo = sync.OnceValue(func() embeddedInfo { return ret }) +// tailscaleToolchainRev returns the git hash of the Tailscale Go toolchain +// used to build this binary, if any. It is read separately from getEmbeddedInfo +// because that function discards build info when VCS fields are missing (e.g. +// in test binaries), but the toolchain rev is still present. +var tailscaleToolchainRev = sync.OnceValue(func() string { + bi, ok := debug.ReadBuildInfo() + if !ok { + return "" + } + for _, s := range bi.Settings { + if s.Key == "tailscale.toolchain.rev" { + return s.Value + } + } + return "" +}) + func gitCommit() string { if gitCommitStamp != "" { return gitCommitStamp @@ -172,6 +189,10 @@ func majorMinorPatch() string { return ret } +// IsTailscaleGo reports whether the current binary was built with +// Tailscale's custom Go toolchain. +func IsTailscaleGo() bool { return isTailscaleGo } + func isValidLongWithTwoRepos(v string) bool { s := strings.Split(v, "-") if len(s) != 3 { diff --git a/version/version_internal_test.go b/version/version_internal_test.go index c78df4ff81a70..72b2dcd5f1f5d 100644 --- a/version/version_internal_test.go +++ b/version/version_internal_test.go @@ -3,7 +3,13 @@ package version -import "testing" +import ( + "os/exec" + "strings" + "testing" + + "tailscale.com/util/cibuild" +) func TestIsValidLongWithTwoRepos(t *testing.T) { tests := []struct { @@ -26,6 +32,26 @@ func TestIsValidLongWithTwoRepos(t *testing.T) { } } +func TestTailscaleToolchainRev(t *testing.T) { + out, err := exec.Command("go", "env", "GOROOT").Output() + if err != nil { + t.Fatalf("go env GOROOT: %v", err) + } + goRoot := strings.TrimSpace(string(out)) + isTsgo := strings.Contains(goRoot, "/.cache/tsgo/") + if !cibuild.On() && !isTsgo { + t.Skip("skipping; not in CI and not using the Tailscale Go toolchain") + } + if !isTailscaleGo { + t.Skip("skipping; not built with tailscale_go build tag") + } + rev := tailscaleToolchainRev() + if rev == "" { + t.Fatal("tailscale.toolchain.rev is empty in build info; expected non-empty when using tsgo") + } + t.Logf("tailscale.toolchain.rev = %s", rev) +} + func TestPrepExeNameForCmp(t *testing.T) { cases := []struct { exe string diff --git a/version/version_not_tsgo.go b/version/version_not_tsgo.go new file mode 100644 index 0000000000000..a852964ab106c --- /dev/null +++ b/version/version_not_tsgo.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !tailscale_go + +package version + +const isTailscaleGo = false diff --git a/version/version_test.go b/version/version_test.go index ebae7f177613a..01fcd47ecc0b5 100644 --- a/version/version_test.go +++ b/version/version_test.go @@ -6,6 +6,8 @@ package version_test import ( "bytes" "os" + "path" + "runtime/debug" "testing" ts "tailscale.com" @@ -30,7 +32,7 @@ func readAlpineTag(t *testing.T, file string) string { if err != nil { t.Fatal(err) } - for _, line := range bytes.Split(f, []byte{'\n'}) { + for line := range bytes.SplitSeq(f, []byte{'\n'}) { line = bytes.TrimSpace(line) _, suf, ok := bytes.Cut(line, []byte("FROM alpine:")) if !ok { @@ -49,3 +51,21 @@ func TestShortAllocs(t *testing.T) { t.Errorf("allocs = %v; want 0", allocs) } } + +func BenchmarkCmdName(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + _ = version.CmdName() + } +} + +func BenchmarkReadBuildInfo(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + info, ok := debug.ReadBuildInfo() + if !ok { + b.Fatal("ReadBuildInfo failed") + } + _ = path.Base(info.Path) + } +} diff --git a/version/version_tsgo.go b/version/version_tsgo.go new file mode 100644 index 0000000000000..fd72af7d4a733 --- /dev/null +++ b/version/version_tsgo.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build tailscale_go + +package version + +const isTailscaleGo = true diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index c588a506e0dc9..a3b9a8e001e60 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -751,13 +751,13 @@ func ports(s string) PortRange { } var fs, ls string - i := strings.IndexByte(s, '-') - if i == -1 { + before, after, ok := strings.Cut(s, "-") + if !ok { fs = s ls = fs } else { - fs = s[:i] - ls = s[i+1:] + fs = before + ls = after } first, err := strconv.ParseInt(fs, 10, 16) if err != nil { diff --git a/wgengine/magicsock/debughttp.go b/wgengine/magicsock/debughttp.go index 68019d0a76cbb..a9f4734f9653e 100644 --- a/wgengine/magicsock/debughttp.go +++ b/wgengine/magicsock/debughttp.go @@ -108,8 +108,8 @@ func (c *Conn) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { } sort.Slice(ent, func(i, j int) bool { return ent[i].pub.Less(ent[j].pub) }) - peers := map[key.NodePublic]tailcfg.NodeView{} - for _, p := range c.peers.All() { + peers := make(map[key.NodePublic]tailcfg.NodeView, len(c.peersByID)) + for _, p := range c.peersByID { peers[p.Key()] = p } diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index f9e5050705b31..72c75db5a835a 100644 --- a/wgengine/magicsock/derp.go +++ b/wgengine/magicsock/derp.go @@ -6,6 +6,7 @@ package magicsock import ( "bufio" "context" + "crypto/tls" "fmt" "maps" "net" @@ -101,6 +102,7 @@ type activeDerp struct { var ( pickDERPFallbackForTests func() int + reSTUNHookForTests func(why string) ) // pickDERPFallback returns a non-zero but deterministic DERP node to @@ -154,7 +156,7 @@ var checkControlHealthDuringNearestDERPInTests = false // region that it selected and set (via setNearestDERP). // // c.mu must NOT be held. -func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) { +func (c *Conn) maybeSetNearestDERP(report *netcheck.Report, force bool) (preferredDERP int) { // Don't change our PreferredDERP if we don't have a connection to // control; if we don't, then we can't inform peers about a DERP home // change, which breaks all connectivity. Even if this DERP region is @@ -168,7 +170,10 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) // // Despite the above behaviour, ensure that we set the nearest DERP if // we don't currently have one set; any DERP server is better than - // none, even if not connected to control. + // none, even if not connected to control. The exception here is if we have + // a cached netmap with a previous DERP server. Retaining the previous DERP + // makes it easier for other nodes to find each other when control is not + // available. var connectedToControl bool if testenv.InTest() && !checkControlHealthDuringNearestDERPInTests { connectedToControl = true @@ -178,7 +183,7 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) c.mu.Lock() myDerp := c.myDerp c.mu.Unlock() - if !connectedToControl { + if !connectedToControl && !force { if myDerp != 0 { metricDERPHomeNoChangeNoControl.Add(1) return myDerp @@ -197,15 +202,32 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) } if preferredDERP != myDerp { c.logf( - "magicsock: home DERP changing from derp-%d [%dms] to derp-%d [%dms]", - c.myDerp, report.RegionLatency[myDerp].Milliseconds(), preferredDERP, report.RegionLatency[preferredDERP].Milliseconds()) + "magicsock: home DERP changing from derp-%d [%dms] to derp-%d [%dms] (forced=%t)", + c.myDerp, report.RegionLatency[myDerp].Milliseconds(), preferredDERP, report.RegionLatency[preferredDERP].Milliseconds(), force) } if !c.setNearestDERP(preferredDERP) { preferredDERP = 0 + } else if preferredDERP != myDerp { + c.homeDERPChangedPub.Publish(HomeDERPChanged{Old: myDerp, New: preferredDERP}) } return } +// HomeDERPChanged is an event sent on the [eventbus.Bus] when a new home DERP +// server has been selected. Its publisher is [magicsock.Coon]; its main +// subscriber is [ipnlocal.LocalBackend] that updates the homeDERP used by the +// netmap cache. +// TODO(cmol): Move the subscriber to not inject into localBackend, but rather +// into the netmap at the controlClient mapSession level once there is a stable +// abstraction to use. +type HomeDERPChanged struct { + Old, New int +} + +func (c *Conn) ForceSetNearestDERP(regionID int) int { + return c.maybeSetNearestDERP(&netcheck.Report{PreferredDERP: regionID}, true) +} + func (c *Conn) derpRegionCodeLocked(regionID int) string { if c.derpMap == nil { return "" @@ -392,6 +414,9 @@ func (c *Conn) derpWriteChanForRegion(regionID int, peer key.NodePublic) chan de return derpMap.Regions[regionID] }) dc.HealthTracker = c.health + if c.extraRootCAs != nil { + dc.TLSConfig = &tls.Config{RootCAs: c.extraRootCAs} + } dc.SetCanAckPings(true) dc.NotePreferred(c.myDerp == regionID) @@ -725,6 +750,10 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en return 0, nil } + if c.onDERPRecv != nil && c.onDERPRecv(regionID, dm.src, b[:n]) { + return 0, nil + } + var ok bool c.mu.Lock() ep, ok = c.peerMap.endpointForNodeKey(dm.src) @@ -745,6 +774,15 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en return n, ep } +// SendDERPPacketTo sends an arbitrary packet to the given node key via +// the DERP relay for the given region. It creates the DERP connection +// to the region if one doesn't already exist. +func (c *Conn) SendDERPPacketTo(dstKey key.NodePublic, regionID int, pkt []byte) (sent bool, err error) { + return c.sendAddr( + netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(regionID)), + dstKey, pkt, false, false) +} + // SetOnlyTCP443 set whether the magicsock connection is restricted // to only using TCP port 443 outbound. If true, no UDP is allowed, // no STUN checks are performend, etc. @@ -754,7 +792,24 @@ func (c *Conn) SetOnlyTCP443(v bool) { // SetDERPMap controls which (if any) DERP servers are used. // A nil value means to disable DERP; it's disabled by default. +// +// SetDERPMap triggers a ReSTUN after updating the map. Callers that want to +// set the map without triggering a ReSTUN should use [Conn.SetDERPMapWithoutReSTUN] +// instead. func (c *Conn) SetDERPMap(dm *tailcfg.DERPMap) { + c.setDERPMap(dm, true) +} + +// SetDERPMapWithoutReSTUN is like [Conn.SetDERPMap] but does not trigger a +// ReSTUN after updating the map. +// +// It is used for setting the map from a cache, so the homeDERP can be set +// from cache before any STUN happens. +func (c *Conn) SetDERPMapWithoutReSTUN(dm *tailcfg.DERPMap) { + c.setDERPMap(dm, false) +} + +func (c *Conn) setDERPMap(dm *tailcfg.DERPMap, doReStun bool) { c.mu.Lock() defer c.mu.Unlock() @@ -811,8 +866,14 @@ func (c *Conn) SetDERPMap(dm *tailcfg.DERPMap) { } } - go c.ReSTUN("derp-map-update") + if doReStun { + if reSTUNHookForTests != nil { + reSTUNHookForTests("derp-map-update") + } + go c.ReSTUN("derp-map-update") + } } + func (c *Conn) wantDerpLocked() bool { return c.derpMap != nil } // c.mu must be held. diff --git a/wgengine/magicsock/derp_test.go b/wgengine/magicsock/derp_test.go index 084f710d8526d..c79882d54c2e5 100644 --- a/wgengine/magicsock/derp_test.go +++ b/wgengine/magicsock/derp_test.go @@ -4,9 +4,15 @@ package magicsock import ( + "fmt" "testing" + "tailscale.com/health" "tailscale.com/net/netcheck" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/util/eventbus" + "tailscale.com/util/eventbus/eventbustest" ) func CheckDERPHeuristicTimes(t *testing.T) { @@ -14,3 +20,111 @@ func CheckDERPHeuristicTimes(t *testing.T) { t.Errorf("PreferredDERPFrameTime too low; should be at least frameReceiveRecordRate") } } + +func TestForceSetNearestDERP(t *testing.T) { + derpMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 7: { + RegionID: 7, + RegionCode: "test", + Nodes: []*tailcfg.DERPNode{ + { + Name: "7a", + RegionID: 7, + HostName: "derp7.test.unused", + IPv4: "127.0.0.1", + IPv6: "none", + }, + }, + }, + }, + } + + // Force the real control health check so we can verify force=true bypasses it. + tstest.Replace(t, &checkControlHealthDuringNearestDERPInTests, true) + + bus := eventbustest.NewBus(t) + ht := health.NewTracker(bus) + c := newConn(t.Logf) + ec := bus.Client("magicsock.Conn.Test") + c.eventClient = ec + c.homeDERPChangedPub = eventbus.Publish[HomeDERPChanged](ec) + c.eventBus = bus + c.derpMap = derpMap + c.health = ht + + ht.SetOutOfPollNetMap() + + tw := eventbustest.NewWatcher(t, bus) + + got := c.ForceSetNearestDERP(7) + if got != 7 { + t.Fatalf("ForceSetNearestDERP(7) = %d, want 7", got) + } + if c.myDerp != 7 { + t.Errorf("c.myDerp = %d after ForceSetNearestDERP, want 7", c.myDerp) + } + + if err := eventbustest.Expect(tw, func(e HomeDERPChanged) error { + if e.Old != 0 || e.New != 7 { + return fmt.Errorf("got HomeDERPChanged{Old:%d, New:%d}, want {Old:0, New:7}", e.Old, e.New) + } + return nil + }); err != nil { + t.Errorf("expected HomeDERPChanged event: %v", err) + } +} + +func TestSetDERPMapDoReStun(t *testing.T) { + derpMap1 := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "cph", + Nodes: []*tailcfg.DERPNode{ + {Name: "1a", RegionID: 1, HostName: "cph.test.unused", IPv4: "127.0.0.1", IPv6: "none"}, + }, + }, + }, + } + derpMap2 := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 2: { + RegionID: 2, + RegionCode: "inc", + Nodes: []*tailcfg.DERPNode{ + {Name: "2a", RegionID: 2, HostName: "inc.test.unused", IPv4: "127.0.0.1", IPv6: "none"}, + }, + }, + }, + } + + var reSTUNCalls int + tstest.Replace(t, &reSTUNHookForTests, func(_ string) { + reSTUNCalls++ + }) + + bus := eventbustest.NewBus(t) + ht := health.NewTracker(bus) + c := newConn(t.Logf) + ec := bus.Client("magicsock.Conn.Test") + c.eventClient = ec + c.homeDERPChangedPub = eventbus.Publish[HomeDERPChanged](ec) + c.eventBus = bus + c.health = ht + // With a zero private key and everHadKey=true, ReSTUN returns early without + // spawning updateEndpoints. + c.everHadKey = true + + // SetDERPMapWithoutReSTUN should not trigger a ReSTUN. + c.SetDERPMapWithoutReSTUN(derpMap1) + if reSTUNCalls != 0 { + t.Errorf("SetDERPMapWithoutReSTUN: got %d ReSTUN calls, want 0", reSTUNCalls) + } + + // SetDERPMap should trigger a ReSTUN. + c.SetDERPMap(derpMap2) + if reSTUNCalls != 1 { + t.Errorf("SetDERPMap: got %d ReSTUN calls, want 1", reSTUNCalls) + } +} diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index 1f99f57ec2d16..71edfe9a1e249 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -40,6 +40,11 @@ import ( var mtuProbePingSizesV4 []int var mtuProbePingSizesV6 []int +// discoKeyAdvertisementInterval tells how often a disco update via TSMP can +// happen. The update is triggered via enqueueCallMeMaybe, and thus it will +// only be sent if the magicsock is in a state to send out CallMeMaybe. +const discoKeyAdvertisementInterval = time.Second * 60 + func init() { for _, m := range tstun.WireMTUsToProbe { mtuProbePingSizesV4 = append(mtuProbePingSizesV4, pktLenToPingSize(m, false)) @@ -80,7 +85,7 @@ type endpoint struct { lastSendAny mono.Time // last time there were outgoing packets sent this peer from any trigger, internal or external to magicsock lastFullPing mono.Time // last time we pinged all disco or wireguard only endpoints lastUDPRelayPathDiscovery mono.Time // last time we ran UDP relay path discovery - sentDiscoKeyAdvertisement bool // wether we sent a TSMPDiscoAdvertisement or not to this endpoint + lastDiscoKeyAdvertisement mono.Time // last time we sent a TSMPDiscoAdvertisement or not to this endpoint derpAddr netip.AddrPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) bestAddr addrQuality // best non-DERP path; zero if none; mutate via setBestAddrLocked() @@ -525,11 +530,6 @@ func (de *endpoint) noteRecvActivity(src epAddr, now mono.Time) bool { elapsed := now.Sub(de.lastRecvWG.LoadAtomic()) if elapsed > 10*time.Second { de.lastRecvWG.StoreAtomic(now) - - if de.c.noteRecvActivity == nil { - return false - } - de.c.noteRecvActivity(de.publicKey) return true } return false @@ -892,7 +892,7 @@ func (de *endpoint) wantUDPRelayPathDiscoveryLocked(now mono.Time) bool { if runtime.GOOS == "js" { return false } - if !de.c.hasPeerRelayServers.Load() { + if !de.c.relayManager.hasPeerRelayServers.Load() { // Changes in this value between its access and a call to // [endpoint.discoverUDPRelayPathsLocked] are fine, we will eventually // do the "right" thing during future path discovery. The worst case is @@ -1178,8 +1178,7 @@ func (de *endpoint) discoPingTimeout(txid stun.TxID) { return } bestUntrusted := mono.Now().After(de.trustBestAddrUntil) - if sp.to == de.bestAddr.epAddr && sp.to.vni.IsSet() && bestUntrusted { - // TODO(jwhited): consider applying this to direct UDP paths as well + if sp.to == de.bestAddr.epAddr && bestUntrusted { de.clearBestAddrLocked() } if debugDisco() || !de.bestAddr.ap.IsValid() || bestUntrusted { @@ -1465,6 +1464,19 @@ func (de *endpoint) setLastPing(ipp netip.AddrPort, now mono.Time) { state.lastPing = now } +// updateDiscoKey replaces the disco key for de. If the key is a zero value key, +// set the key to nil. +func (de *endpoint) updateDiscoKey(key key.DiscoPublic) { + if key.IsZero() { + de.disco.Store(nil) + } else { + de.disco.Store(&endpointDisco{ + key: key, + short: key.ShortString(), + }) + } +} + // updateFromNode updates the endpoint based on a tailcfg.Node from a NetMap // update. func (de *endpoint) updateFromNode(n tailcfg.NodeView, heartbeatDisabled bool, probeUDPLifetimeEnabled bool) { @@ -1490,15 +1502,12 @@ func (de *endpoint) updateFromNode(n tailcfg.NodeView, heartbeatDisabled bool, p if discoKey != n.DiscoKey() { de.c.logf("[v1] magicsock: disco: node %s changed from %s to %s", de.publicKey.ShortString(), discoKey, n.DiscoKey()) - de.disco.Store(&endpointDisco{ - key: n.DiscoKey(), - short: n.DiscoKey().ShortString(), - }) + key := n.DiscoKey() + de.updateDiscoKey(key) de.debugUpdates.Add(EndpointChange{ When: time.Now(), What: "updateFromNode-resetLocked", }) - de.resetLocked() } if n.HomeDERP() == 0 { if de.derpAddr.IsValid() { @@ -1763,12 +1772,8 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src epAdd latency: latency, wireMTU: pingSizeToPktLen(sp.size, sp.to), } - // TODO(jwhited): consider checking de.trustBestAddrUntil as well. If - // de.bestAddr is untrusted we may want to clear it, otherwise we could - // get stuck with a forever untrusted bestAddr that blackholes, since - // we don't clear direct UDP paths on disco ping timeout (see - // discoPingTimeout). - if betterAddr(thisPong, de.bestAddr) { + bestUntrusted := now.After(de.trustBestAddrUntil) + if betterAddr(thisPong, de.bestAddr) || bestUntrusted { de.c.logf("magicsock: disco: node %v %v now using %v mtu=%v tx=%x", de.publicKey.ShortString(), de.discoShort(), sp.to, thisPong.wireMTU, m.TxID[:6]) de.debugUpdates.Add(EndpointChange{ When: time.Now(), @@ -2083,7 +2088,7 @@ func (de *endpoint) setDERPHome(regionID uint16) { de.mu.Lock() defer de.mu.Unlock() de.derpAddr = netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(regionID)) - if de.c.hasPeerRelayServers.Load() { + if de.c.relayManager.hasPeerRelayServers.Load() { de.c.relayManager.handleDERPHomeChange(de.publicKey, regionID) } } diff --git a/wgengine/magicsock/endpoint_test.go b/wgengine/magicsock/endpoint_test.go index 43ff012c73d61..593cf1455b0c6 100644 --- a/wgengine/magicsock/endpoint_test.go +++ b/wgengine/magicsock/endpoint_test.go @@ -6,12 +6,16 @@ package magicsock import ( "net/netip" "testing" + "testing/synctest" "time" + "tailscale.com/disco" "tailscale.com/net/packet" + "tailscale.com/net/stun" "tailscale.com/tailcfg" "tailscale.com/tstime/mono" "tailscale.com/types/key" + "tailscale.com/util/ringlog" ) func TestProbeUDPLifetimeConfig_Equals(t *testing.T) { @@ -180,7 +184,7 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { wantMaybe bool }{ { - name: "nil probeUDPLifetime", + name: "nil-probeUDPLifetime", localDisco: higher, remoteDisco: &lower, probeUDPLifetimeFn: func() *probeUDPLifetime { @@ -189,28 +193,28 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { bestAddr: addr, }, { - name: "local higher disco key", + name: "local-higher-disco-key", localDisco: higher, remoteDisco: &lower, probeUDPLifetimeFn: newProbeUDPLifetime, bestAddr: addr, }, { - name: "remote no disco key", + name: "remote-no-disco-key", localDisco: higher, remoteDisco: nil, probeUDPLifetimeFn: newProbeUDPLifetime, bestAddr: addr, }, { - name: "invalid bestAddr", + name: "invalid-bestAddr", localDisco: lower, remoteDisco: &higher, probeUDPLifetimeFn: newProbeUDPLifetime, bestAddr: addrQuality{}, }, { - name: "cycle started too recently", + name: "cycle-started-too-recently", localDisco: lower, remoteDisco: &higher, probeUDPLifetimeFn: func() *probeUDPLifetime { @@ -222,7 +226,7 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { bestAddr: addr, }, { - name: "maybe cliff 0 cycle not active", + name: "maybe-cliff-0-cycle-not-active", localDisco: lower, remoteDisco: &higher, probeUDPLifetimeFn: func() *probeUDPLifetime { @@ -238,7 +242,7 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { wantMaybe: true, }, { - name: "maybe cliff 0", + name: "maybe-cliff-0", localDisco: lower, remoteDisco: &higher, probeUDPLifetimeFn: func() *probeUDPLifetime { @@ -254,7 +258,7 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { wantMaybe: true, }, { - name: "maybe cliff 1", + name: "maybe-cliff-1", localDisco: lower, remoteDisco: &higher, probeUDPLifetimeFn: func() *probeUDPLifetime { @@ -270,7 +274,7 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { wantMaybe: true, }, { - name: "maybe cliff 2", + name: "maybe-cliff-2", localDisco: lower, remoteDisco: &higher, probeUDPLifetimeFn: func() *probeUDPLifetime { @@ -341,13 +345,13 @@ func Test_epAddr_isDirectUDP(t *testing.T) { want: true, }, { - name: "false derp magic addr", + name: "false-derp-magic-addr", ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 0), vni: packet.VirtualNetworkID{}, want: false, }, { - name: "false vni set", + name: "false-vni-set", ap: netip.MustParseAddrPort("192.0.2.1:7"), vni: vni, want: false, @@ -397,42 +401,42 @@ func Test_endpoint_udpRelayEndpointReady(t *testing.T) { wantBestAddr addrQuality }{ { - name: "bestAddr trusted direct", + name: "bestAddr-trusted-direct", curBestAddr: directAddrQuality, trustBestAddrUntil: mono.Now().Add(1 * time.Hour), maybeBest: peerRelayAddrQuality, wantBestAddr: directAddrQuality, }, { - name: "bestAddr untrusted direct", + name: "bestAddr-untrusted-direct", curBestAddr: directAddrQuality, trustBestAddrUntil: mono.Now().Add(-1 * time.Hour), maybeBest: peerRelayAddrQuality, wantBestAddr: peerRelayAddrQuality, }, { - name: "maybeBest same relay server higher latency bestAddr trusted", + name: "maybeBest-same-relay-higher-latency-trusted", curBestAddr: peerRelayAddrQuality, trustBestAddrUntil: mono.Now().Add(1 * time.Hour), maybeBest: peerRelayAddrQualityHigherLatencySameServer, wantBestAddr: peerRelayAddrQualityHigherLatencySameServer, }, { - name: "maybeBest diff relay server higher latency bestAddr trusted", + name: "maybeBest-diff-relay-higher-latency-trusted", curBestAddr: peerRelayAddrQuality, trustBestAddrUntil: mono.Now().Add(1 * time.Hour), maybeBest: peerRelayAddrQualityHigherLatencyDiffServer, wantBestAddr: peerRelayAddrQuality, }, { - name: "maybeBest diff relay server lower latency bestAddr trusted", + name: "maybeBest-diff-relay-lower-latency-trusted", curBestAddr: peerRelayAddrQuality, trustBestAddrUntil: mono.Now().Add(1 * time.Hour), maybeBest: peerRelayAddrQualityLowerLatencyDiffServer, wantBestAddr: peerRelayAddrQualityLowerLatencyDiffServer, }, { - name: "maybeBest diff relay server equal latency bestAddr trusted", + name: "maybeBest-diff-relay-equal-latency-trusted", curBestAddr: peerRelayAddrQuality, trustBestAddrUntil: mono.Now().Add(1 * time.Hour), maybeBest: peerRelayAddrQualityEqualLatencyDiffServer, @@ -453,3 +457,233 @@ func Test_endpoint_udpRelayEndpointReady(t *testing.T) { }) } } + +func Test_endpoint_discoPingTimeout(t *testing.T) { + expired := -1 * time.Hour + valid := 1 * time.Hour + directAddrA := epAddr{ap: netip.MustParseAddrPort("192.0.2.1:7")} + relayAddrA := epAddr{ap: netip.MustParseAddrPort("192.0.2.2:77")} + relayAddrA.vni.Set(1) + directAddrB := epAddr{ap: netip.MustParseAddrPort("192.0.2.3:7")} + relayAddrB := epAddr{ap: netip.MustParseAddrPort("192.0.2.4:77")} + relayAddrB.vni.Set(1) + + for _, tc := range []struct { + name string + bestAddr addrQuality + trustBestAddrUntil time.Duration + pingTo epAddr + wantBestAddrCleared bool + }{ + { + name: "relay-path-trust-expired", + bestAddr: addrQuality{epAddr: relayAddrA}, + trustBestAddrUntil: expired, + pingTo: relayAddrA, + wantBestAddrCleared: true, + }, + { + name: "direct-udp-path-trust-expired", + bestAddr: addrQuality{epAddr: directAddrA}, + trustBestAddrUntil: expired, + pingTo: directAddrA, + wantBestAddrCleared: true, + }, + { + name: "direct-udp-path-trust-valid", + bestAddr: addrQuality{epAddr: directAddrA}, + trustBestAddrUntil: valid, + pingTo: directAddrA, + wantBestAddrCleared: false, + }, + { + name: "relay-path-trust-valid", + bestAddr: addrQuality{epAddr: relayAddrA}, + trustBestAddrUntil: valid, + pingTo: relayAddrA, + wantBestAddrCleared: false, + }, + { + name: "ping-to-different-direct-addr-trust-expired", + bestAddr: addrQuality{epAddr: directAddrA}, + trustBestAddrUntil: expired, + pingTo: directAddrB, + wantBestAddrCleared: false, + }, + { + name: "ping-to-different-relay-addr-trust-expired", + bestAddr: addrQuality{epAddr: relayAddrA}, + trustBestAddrUntil: expired, + pingTo: relayAddrB, + wantBestAddrCleared: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + now := mono.Now() // synctest to match this to the internal 'now' + c := &Conn{ + logf: func(msg string, args ...any) {}, + } + c.discoAtomic.Set(key.NewDisco()) + de := &endpoint{ + c: c, + bestAddr: tc.bestAddr, + trustBestAddrUntil: now.Add(tc.trustBestAddrUntil), + sentPing: make(map[stun.TxID]sentPing), + } + txid := stun.NewTxID() + timer := time.NewTimer(time.Hour) + timer.Stop() + de.sentPing[txid] = sentPing{ + to: tc.pingTo, + at: now.Add(-100 * time.Millisecond), + timer: timer, + purpose: pingDiscovery, + } + + de.discoPingTimeout(txid) + if tc.wantBestAddrCleared { + if de.bestAddr.ap.IsValid() { + t.Errorf("expected bestAddr to be cleared, but bestAddr.ap is valid: %v", de.bestAddr.ap) + } + if de.trustBestAddrUntil != 0 { + t.Errorf("expected trustBestAddrUntil to be cleared, but got: %v", de.trustBestAddrUntil) + } + } else { + if de.bestAddr != tc.bestAddr { + t.Errorf("expected bestAddr to be unchanged, got: %v, want: %v", de.bestAddr, tc.bestAddr) + } + } + if _, ok := de.sentPing[txid]; ok { + t.Errorf("expected sentPing[txid] to be removed, but it still exists") + } + }) + }) + } +} + +func Test_endpoint_handlePongConnLocked(t *testing.T) { + goodLatency := 50 * time.Millisecond + badLatency := 100 * time.Millisecond + expired := -1 * time.Hour + valid := 1 * time.Hour + directAddrA := epAddr{ap: netip.MustParseAddrPort("192.0.2.1:7")} + directAddrB := epAddr{ap: netip.MustParseAddrPort("192.0.2.2:8")} + derpAddr := epAddr{ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 0)} + + for _, tc := range []struct { + name string + bestAddr addrQuality + trustBestAddrUntil time.Duration + pongFrom epAddr + pongLatency time.Duration + wantBestAddr epAddr + }{ + { + name: "better-latency-trust-valid", + bestAddr: addrQuality{epAddr: directAddrA, latency: badLatency}, + trustBestAddrUntil: valid, + pongFrom: directAddrB, + pongLatency: goodLatency, + wantBestAddr: directAddrB, + }, + { + name: "worse-latency-trust-valid", + bestAddr: addrQuality{epAddr: directAddrA, latency: goodLatency}, + trustBestAddrUntil: valid, + pongFrom: directAddrB, + pongLatency: badLatency, + wantBestAddr: directAddrA, + }, + { + name: "worse-latency-trust-expired", + bestAddr: addrQuality{epAddr: directAddrA, latency: goodLatency}, + trustBestAddrUntil: expired, + pongFrom: directAddrB, + pongLatency: badLatency, + wantBestAddr: directAddrB, + }, + { + name: "same-path-trust-expired", + bestAddr: addrQuality{epAddr: directAddrA, latency: badLatency}, + trustBestAddrUntil: expired, + pongFrom: directAddrA, + pongLatency: goodLatency, // updated latency + wantBestAddr: directAddrA, + }, + { + name: "derp-pong-trust-expired", + bestAddr: addrQuality{epAddr: directAddrA, latency: badLatency}, + trustBestAddrUntil: expired, + pongFrom: derpAddr, + pongLatency: goodLatency, + wantBestAddr: directAddrA, + }, + { + name: "better-latency-trust-expired", + bestAddr: addrQuality{epAddr: directAddrA, latency: badLatency}, + trustBestAddrUntil: expired, + pongFrom: directAddrB, + pongLatency: goodLatency, + wantBestAddr: directAddrB, + }, + } { + t.Run(tc.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + now := mono.Now() // synctest to match this to the internal 'now' + pm := newPeerMap() + c := &Conn{ + logf: func(msg string, args ...any) {}, + peerMap: pm, + } + c.discoAtomic.Set(key.NewDisco()) + de := &endpoint{ + c: c, + bestAddr: tc.bestAddr, + bestAddrAt: now.Add(-5 * time.Minute), + trustBestAddrUntil: now.Add(tc.trustBestAddrUntil), + sentPing: make(map[stun.TxID]sentPing), + endpointState: make(map[netip.AddrPort]*endpointState), + debugUpdates: ringlog.New[EndpointChange](10), + } + txid := stun.NewTxID() + pong := &disco.Pong{ + TxID: txid, + Src: tc.pongFrom.ap, + } + timer := time.NewTimer(time.Hour) + timer.Stop() + de.sentPing[txid] = sentPing{ + to: tc.pongFrom, + at: now.Add(-tc.pongLatency), + timer: timer, + purpose: pingDiscovery, + } + if tc.pongFrom.ap.Addr() != tailcfg.DerpMagicIPAddr && !tc.pongFrom.vni.IsSet() { + de.endpointState[tc.pongFrom.ap] = &endpointState{} + } + di := &discoInfo{ + discoKey: key.NewDisco().Public(), + discoShort: "test", + } + + knownTxID := de.handlePongConnLocked(pong, di, tc.pongFrom) + if !knownTxID { + t.Errorf("expected knownTxID to be true, got false") + } + if de.bestAddr.epAddr != tc.wantBestAddr { + t.Errorf("expected bestAddr.epAddr to be %v, got: %v", tc.wantBestAddr, de.bestAddr.epAddr) + } + if tc.pongFrom == tc.bestAddr.epAddr && de.bestAddr.latency-tc.pongLatency > 0 { + t.Errorf("expected latency to be %v, got: %v", tc.pongLatency, de.bestAddr.latency) + } + if tc.pongFrom != derpAddr && de.trustBestAddrUntil.Before(now) { + t.Errorf("expected trustBestAddrUntil to be refreshed, but it's in the past: %v", de.trustBestAddrUntil) + } + if _, ok := de.sentPing[txid]; ok { + t.Errorf("expected sentPing[txid] to be removed, but it still exists") + } + }) + }) + } +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 169369f4bb472..9720f57cd0511 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -9,6 +9,7 @@ import ( "bufio" "bytes" "context" + "crypto/x509" "encoding/binary" "errors" "fmt" @@ -163,10 +164,11 @@ type Conn struct { derpActiveFunc func() idleFunc func() time.Duration // nil means unknown testOnlyPacketListener nettype.PacketListener - noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity - netMon *netmon.Monitor // must be non-nil - health *health.Tracker // or nil - controlKnobs *controlknobs.Knobs // or nil + onDERPRecv func(int, key.NodePublic, []byte) bool // or nil, see Options.OnDERPRecv + netMon *netmon.Monitor // must be non-nil + health *health.Tracker // or nil + extraRootCAs *x509.CertPool // additional trusted root CAs; or nil + controlKnobs *controlknobs.Knobs // or nil // ================================================================ // No locking required to access these fields, either because @@ -180,6 +182,7 @@ type Conn struct { allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq] portUpdatePub *eventbus.Publisher[router.PortUpdate] tsmpDiscoKeyAvailablePub *eventbus.Publisher[NewDiscoKeyAvailable] + homeDERPChangedPub *eventbus.Publisher[HomeDERPChanged] // pconn4 and pconn6 are the underlying UDP sockets used to // send/receive packets for wireguard and other magicsock @@ -266,12 +269,6 @@ type Conn struct { // captureHook, if non-nil, is the pcap logging callback when capturing. captureHook syncs.AtomicValue[packet.CaptureCallback] - // hasPeerRelayServers is whether [relayManager] is configured with at least - // one peer relay server via [relayManager.handleRelayServersSet]. It exists - // to suppress calls into [relayManager] leading to wasted work involving - // channel operations and goroutine creation. - hasPeerRelayServers atomic.Bool - // discoAtomic is the current disco private and public keypair for this conn. discoAtomic discoAtomic @@ -358,18 +355,18 @@ type Conn struct { // magicsock could do with any complexity reduction it can get. netInfoLast *tailcfg.NetInfo - derpMap *tailcfg.DERPMap // nil (or zero regions/nodes) means DERP is disabled - self tailcfg.NodeView // from last SetNetworkMap - peers views.Slice[tailcfg.NodeView] // from last SetNetworkMap, sorted by Node.ID; Note: [netmap.NodeMutation]'s rx'd in UpdateNetmapDelta are never applied - filt *filter.Filter // from last SetFilter - relayClientEnabled bool // whether we can allocate UDP relay endpoints on UDP relay servers or receive CallMeMaybeVia messages from peers - lastFlags debugFlags // at time of last SetNetworkMap - privateKey key.NodePrivate // WireGuard private key for this node - everHadKey bool // whether we ever had a non-zero private key - myDerp int // nearest DERP region ID; 0 means none/unknown - homeless bool // if true, don't try to find & stay conneted to a DERP home (myDerp will stay 0) - derpStarted chan struct{} // closed on first connection to DERP; for tests & cleaner Close - activeDerp map[int]activeDerp // DERP regionID -> connection to a node in that region + derpMap *tailcfg.DERPMap // nil (or zero regions/nodes) means DERP is disabled + self tailcfg.NodeView // from last SetNetworkMap + peersByID map[tailcfg.NodeID]tailcfg.NodeView // current peer set, keyed by NodeID. Maintained by SetNetworkMap/UpsertPeer/RemovePeer. Note: per-field NodeMutation patches received in UpdateNetmapDelta are never applied to these snapshots. + filt *filter.Filter // from last SetFilter + relayClientEnabled bool // whether we can allocate UDP relay endpoints on UDP relay servers or receive CallMeMaybeVia messages from peers + lastFlags debugFlags // at time of last SetNetworkMap + privateKey key.NodePrivate // WireGuard private key for this node + everHadKey bool // whether we ever had a non-zero private key + myDerp int // nearest DERP region ID; 0 means none/unknown + homeless bool // if true, don't try to find & stay conneted to a DERP home (myDerp will stay 0) + derpStarted chan struct{} // closed on first connection to DERP; for tests & cleaner Close + activeDerp map[int]activeDerp // DERP regionID -> connection to a node in that region prevDerp map[int]*syncs.WaitGroupChan // derpRoute contains optional alternate routes to use as an @@ -459,19 +456,6 @@ type Options struct { // Only used by tests. TestOnlyPacketListener nettype.PacketListener - // NoteRecvActivity, if provided, is a func for magicsock to call - // whenever it receives a packet from a a peer if it's been more - // than ~10 seconds since the last one. (10 seconds is somewhat - // arbitrary; the sole user, lazy WireGuard configuration, - // just doesn't need or want it called on - // every packet, just every minute or two for WireGuard timeouts, - // and 10 seconds seems like a good trade-off between often enough - // and not too often.) - // The provided func is likely to call back into - // Conn.ParseEndpoint, which acquires Conn.mu. As such, you should - // not hold Conn.mu while calling it. - NoteRecvActivity func(key.NodePublic) - // NetMon is the network monitor to use. // It must be non-nil. NetMon *netmon.Monitor @@ -480,6 +464,10 @@ type Options struct { // report errors and warnings to. HealthTracker *health.Tracker + // ExtraRootCAs, if non-nil, specifies additional trusted root CAs + // for TLS connections to DERP servers. + ExtraRootCAs *x509.CertPool + // Metrics specifies the metrics registry to record metrics to. Metrics *usermetric.Registry @@ -495,6 +483,20 @@ type Options struct { // DisablePortMapper, if true, disables the portmapper. // This is primarily useful in tests. DisablePortMapper bool + + // ForceDiscoKey, if non-zero, forces the use of a specific disco + // private key. This should only be used for special cases and + // experiments, not for production. The recommended normal path is to + // leave it zero, in which case a new disco key is generated per + // Tailscale start and kept only in memory. + ForceDiscoKey key.DiscoPrivate + + // OnDERPRecv, if non-nil, is called for every non-disco packet + // received from DERP before the peer map lookup. If it returns + // true, the packet is considered handled and is not passed to + // WireGuard. The pkt slice is borrowed and must be copied if + // the callee needs to retain it. + OnDERPRecv func(regionID int, src key.NodePublic, pkt []byte) bool } func (o *Options) logf() logger.Logf { @@ -622,6 +624,9 @@ func NewConn(opts Options) (*Conn, error) { } c := newConn(opts.logf()) + if !opts.ForceDiscoKey.IsZero() { + c.discoAtomic.Set(opts.ForceDiscoKey) + } c.eventBus = opts.EventBus c.port.Store(uint32(opts.Port)) c.controlKnobs = opts.ControlKnobs @@ -629,7 +634,7 @@ func NewConn(opts Options) (*Conn, error) { c.derpActiveFunc = opts.derpActiveFunc() c.idleFunc = opts.IdleFunc c.testOnlyPacketListener = opts.TestOnlyPacketListener - c.noteRecvActivity = opts.NoteRecvActivity + c.onDERPRecv = opts.OnDERPRecv // Set up publishers and subscribers. Subscribe calls must return before // NewConn otherwise published events can be missed. @@ -638,6 +643,7 @@ func NewConn(opts Options) (*Conn, error) { c.allocRelayEndpointPub = eventbus.Publish[UDPRelayAllocReq](ec) c.portUpdatePub = eventbus.Publish[router.PortUpdate](ec) c.tsmpDiscoKeyAvailablePub = eventbus.Publish[NewDiscoKeyAvailable](ec) + c.homeDERPChangedPub = eventbus.Publish[HomeDERPChanged](ec) eventbus.SubscribeFunc(ec, c.onPortMapChanged) eventbus.SubscribeFunc(ec, c.onUDPRelayAllocResp) @@ -667,6 +673,7 @@ func NewConn(opts Options) (*Conn, error) { c.netMon = opts.NetMon c.health = opts.HealthTracker + c.extraRootCAs = opts.ExtraRootCAs c.getPeerByKey = opts.PeerByKeyFunc if err := c.rebind(keepCurrentPort); err != nil { @@ -1036,7 +1043,7 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { ni.OSHasIPv6.Set(report.OSHasIPv6) ni.WorkingUDP.Set(report.UDP) ni.WorkingICMPv4.Set(report.ICMPv4) - ni.PreferredDERP = c.maybeSetNearestDERP(report) + ni.PreferredDERP = c.maybeSetNearestDERP(report, false) ni.FirewallMode = hostinfo.FirewallMode() c.callNetInfoCallback(ni) @@ -1201,7 +1208,7 @@ func (c *Conn) RotateDiscoKey() { connCtx := c.connCtx for _, endpoint := range c.peerMap.byEpAddr { endpoint.ep.mu.Lock() - endpoint.ep.sentDiscoKeyAdvertisement = false + endpoint.ep.lastDiscoKeyAdvertisement = 0 endpoint.ep.mu.Unlock() } c.mu.Unlock() @@ -1419,7 +1426,18 @@ func (c *Conn) LocalPort() uint16 { var errNetworkDown = errors.New("magicsock: network down") -func (c *Conn) networkDown() bool { return !c.networkUp.Load() } +// This allows tests to pass when the user's machine is offline, but allows us +// to still test network-down behaviour when desired. +var checkNetworkDownDuringTests = false + +func (c *Conn) networkDown() bool { + // For tests, always assume the network is up unless we're explicitly + // testing this behaviour. + if envknob.AssumeNetworkUp() || (testenv.InTest() && !checkNetworkDownDuringTests) { + return false + } + return !c.networkUp.Load() +} // Send implements conn.Bind. // @@ -1474,8 +1492,7 @@ func (c *Conn) sendUDPBatch(addr epAddr, buffs [][]byte, offset int) (sent bool, err = c.pconn4.WriteWireGuardBatchTo(buffs, addr, offset) } if err != nil { - var errGSO neterror.ErrUDPGSODisabled - if errors.As(err, &errGSO) { + if errGSO, ok := errors.AsType[neterror.ErrUDPGSODisabled](err); ok { c.logf("magicsock: %s", errGSO.Error()) err = errGSO.RetryErr } else { @@ -2394,24 +2411,12 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake if c.filt == nil { return } - // Binary search of peers is O(log n) while c.mu is held. - // TODO: We might be able to use ep.nodeAddr instead of all addresses, - // or we might be able to release c.mu before doing this work. Keep it - // simple and slow for now. c.peers.AsSlice is a copy. We may need to - // write our own binary search for a [views.Slice]. - peerI, ok := slices.BinarySearchFunc(c.peers.AsSlice(), ep.nodeID, func(peer tailcfg.NodeView, target tailcfg.NodeID) int { - if peer.ID() < target { - return -1 - } else if peer.ID() > target { - return 1 - } - return 0 - }) + peer, ok := c.peersByID[ep.nodeID] if !ok { // unexpected return } - if !nodeHasCap(c.filt, c.peers.At(peerI), c.self, tailcfg.PeerCapabilityRelay) { + if !nodeHasCap(c.filt, peer, c.self, tailcfg.PeerCapabilityRelay) { return } // [Conn.mu] must not be held while publishing, or [Conn.onUDPRelayAllocResp] @@ -2748,18 +2753,6 @@ func (c *Conn) UpdatePeers(newPeers set.Set[key.NodePublic]) { } } -func nodesEqual(x, y views.Slice[tailcfg.NodeView]) bool { - if x.Len() != y.Len() { - return false - } - for i := range x.Len() { - if !x.At(i).Equal(y.At(i)) { - return false - } - } - return true -} - // debugRingBufferSize returns a maximum size for our set of endpoint ring // buffers by assuming that a single large update is ~500 bytes, and that we // want to not use more than 1MiB of memory on phones / 4MiB on other devices. @@ -2847,7 +2840,7 @@ func (c *Conn) SetFilter(f *filter.Filter) { c.mu.Lock() c.filt = f self := c.self - peers := c.peers + peers := c.peerSnapshotLocked() relayClientEnabled := c.relayClientEnabled c.mu.Unlock() // release c.mu before potentially calling c.updateRelayServersSet which is O(m * n) @@ -2861,11 +2854,26 @@ func (c *Conn) SetFilter(f *filter.Filter) { c.updateRelayServersSet(f, self, peers) } +// peerSnapshotLocked returns a freshly-allocated slice of the current peers. +// It's used by callers that need to pass peer state to an O(m * n) callee +// (like [Conn.updateRelayServersSet]) after releasing c.mu. c.mu must be held. +func (c *Conn) peerSnapshotLocked() []tailcfg.NodeView { + if len(c.peersByID) == 0 { + return nil + } + out := make([]tailcfg.NodeView, 0, len(c.peersByID)) + for _, p := range c.peersByID { + out = append(out, p) + } + return out +} + // updateRelayServersSet iterates all peers and self, evaluating filt for each // one in order to determine which are relay server candidates. filt, self, and // peers are passed as args (vs c.mu-guarded fields) to enable callers to // release c.mu before calling as this is O(m * n) (we iterate all cap rules 'm' -// in filt for every peer 'n'). +// in filt for every peer 'n'). peers must be a snapshot owned by the caller; +// this function does not retain it after return. // // Calls to updateRelayServersSet must never run concurrent to // [endpoint.setDERPHome], otherwise [candidatePeerRelay] DERP home changes may @@ -2877,9 +2885,9 @@ func (c *Conn) SetFilter(f *filter.Filter) { // them. // 2. Moving this work upstream into [nodeBackend] or similar, and publishing // the computed result over the eventbus instead. -func (c *Conn) updateRelayServersSet(filt *filter.Filter, self tailcfg.NodeView, peers views.Slice[tailcfg.NodeView]) { +func (c *Conn) updateRelayServersSet(filt *filter.Filter, self tailcfg.NodeView, peers []tailcfg.NodeView) { relayServers := make(set.Set[candidatePeerRelay]) - nodes := append(peers.AsSlice(), self) + nodes := append(peers, self) for _, maybeCandidate := range nodes { if maybeCandidate.ID() != self.ID() && !capVerIsRelayCapable(maybeCandidate.Cap()) { // If maybeCandidate's [tailcfg.CapabilityVersion] is not relay-capable, @@ -2897,12 +2905,9 @@ func (c *Conn) updateRelayServersSet(filt *filter.Filter, self tailcfg.NodeView, derpHomeRegionID: uint16(maybeCandidate.HomeDERP()), }) } + // [relayManager]'s run loop updates [relayManager.hasPeerRelayServers] + // to reflect the new server count. c.relayManager.handleRelayServersSet(relayServers) - if len(relayServers) > 0 { - c.hasPeerRelayServers.Store(true) - } else { - c.hasPeerRelayServers.Store(false) - } } // nodeHasCap returns true if src has cap on dst, otherwise it returns false. @@ -2954,6 +2959,12 @@ func (c *candidatePeerRelay) isValid() bool { // magicsock has the current state before subsequent operations proceed. // // self may be invalid if there's no network map. +// +// SetNetworkMap takes the full peer list and walks all of it. For incremental +// updates where only a single peer changes, prefer the O(1) [Conn.UpsertPeer] +// and [Conn.RemovePeer] methods. SetNetworkMap remains the right call for the +// initial netmap and for changes to self or to global state (filter, DERP, +// etc.) that aren't covered by the per-peer methods. func (c *Conn) SetNetworkMap(self tailcfg.NodeView, peers []tailcfg.NodeView) { peersChanged := c.updateNodes(self, peers) @@ -2966,7 +2977,7 @@ func (c *Conn) SetNetworkMap(self tailcfg.NodeView, peers []tailcfg.NodeView) { c.relayClientEnabled = relayClientEnabled filt := c.filt selfView := c.self - peersView := c.peers + peersSnap := c.peerSnapshotLocked() isClosed := c.closed c.mu.Unlock() // release c.mu before potentially calling c.updateRelayServersSet which is O(m * n) @@ -2976,16 +2987,16 @@ func (c *Conn) SetNetworkMap(self tailcfg.NodeView, peers []tailcfg.NodeView) { if peersChanged || relayClientChanged { if !relayClientEnabled { + // [relayManager]'s run loop updates [relayManager.hasPeerRelayServers]. c.relayManager.handleRelayServersSet(nil) - c.hasPeerRelayServers.Store(false) } else { - c.updateRelayServersSet(filt, selfView, peersView) + c.updateRelayServersSet(filt, selfView, peersSnap) } } } // updateNodes updates [Conn] to reflect the given self node and peers. -// It reports whether the peers were changed from before. +// It reports whether the peer set (membership or any field) changed. func (c *Conn) updateNodes(self tailcfg.NodeView, peers []tailcfg.NodeView) (peersChanged bool) { c.mu.Lock() defer c.mu.Unlock() @@ -2994,13 +3005,9 @@ func (c *Conn) updateNodes(self tailcfg.NodeView, peers []tailcfg.NodeView) (pee return false } - priorPeers := c.peers metricNumPeers.Set(int64(len(peers))) - // Update c.self & c.peers regardless, before the following early return. c.self = self - curPeers := views.SliceOf(peers) - c.peers = curPeers // [debugFlags] are mutable in [Conn.SetSilentDisco] & // [Conn.SetProbeUDPLifetime]. These setters are passed [controlknobs.Knobs] @@ -3013,161 +3020,272 @@ func (c *Conn) updateNodes(self tailcfg.NodeView, peers []tailcfg.NodeView) (pee // TODO: mutate [debugFlags] here instead of in various [Conn] setters. flags := c.debugFlagsLocked() - peersChanged = !nodesEqual(priorPeers, curPeers) - if !peersChanged && c.lastFlags == flags { - // The rest of this function is all adjusting state for peers that have - // changed. But if the set of peers is equal and the debug flags (for - // silent disco and probe UDP lifetime) haven't changed, there is no - // need to do anything else. - return + // Fast path: if the peer set and every peer's NodeView are unchanged, + // and flags are unchanged, skip all further work. + if c.lastFlags == flags && len(peers) == len(c.peersByID) { + allSame := true + for _, n := range peers { + if prev, ok := c.peersByID[n.ID()]; !ok || !prev.Equal(n) { + allSame = false + break + } + } + if allSame { + return false + } } c.lastFlags = flags - c.logf("[v1] magicsock: got updated network map; %d peers", len(peers)) entriesPerBuffer := debugRingBufferSize(len(peers)) - // Try a pass of just upserting nodes and creating missing - // endpoints. If the set of nodes is the same, this is an - // efficient alloc-free update. If the set of nodes is different, - // we'll fall through to the next pass, which allocates but can - // handle full set updates. + // Build the new peer map while upserting each peer. + newPeers := make(map[tailcfg.NodeID]tailcfg.NodeView, len(peers)) for _, n := range peers { - if n.ID() == 0 { - devPanicf("node with zero ID") - continue - } - if n.Key().IsZero() { - devPanicf("node with zero key") - continue - } - ep, ok := c.peerMap.endpointForNodeID(n.ID()) - if ok && ep.publicKey != n.Key() { - // The node rotated public keys. Delete the old endpoint and create - // it anew. - c.peerMap.deleteEndpoint(ep) - ok = false + newPeers[n.ID()] = n + c.upsertPeerLocked(n, flags, entriesPerBuffer) + } + if len(newPeers) != len(peers) { + // Duplicate NodeIDs in the input shouldn't happen, but log if so. + c.logf("[unexpected] magicsock.updateNodes: %d peers input but %d unique IDs", len(peers), len(newPeers)) + } + c.peersByID = newPeers + + // If the upsert pass left stale endpoints in peerMap (peers removed + // relative to before), clean them up. + if c.peerMap.nodeCount() != len(newPeers) { + keep := set.Set[key.NodePublic]{} + for _, n := range newPeers { + keep.Add(n.Key()) } - if ok { - // At this point we're modifying an existing endpoint (ep) whose - // public key and nodeID match n. Its other fields (such as disco - // key or endpoints) might've changed. - - if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() { - // Discokey transitioned from non-zero to zero? This should not - // happen in the wild, however it could mean: - // 1. A node was downgraded from post 0.100 to pre 0.100. - // 2. A Tailscale node key was extracted and used on a - // non-Tailscale node (should not enter here due to the - // IsWireGuardOnly check) - // 3. The server is misbehaving. + c.peerMap.forEachEndpoint(func(ep *endpoint) { + if !keep.Contains(ep.publicKey) { c.peerMap.deleteEndpoint(ep) - continue - } - var oldDiscoKey key.DiscoPublic - if epDisco := ep.disco.Load(); epDisco != nil { - oldDiscoKey = epDisco.key } - ep.updateFromNode(n, flags.heartbeatDisabled, flags.probeUDPLifetimeOn) - c.peerMap.upsertEndpoint(ep, oldDiscoKey) // maybe update discokey mappings in peerMap - continue - } + }) + } - if ep, ok := c.peerMap.endpointForNodeKey(n.Key()); ok { - // At this point n.Key() should be for a key we've never seen before. If - // ok was true above, it was an update to an existing matching key and - // we don't get this far. If ok was false above, that means it's a key - // that differs from the one the NodeID had. But double check. - if ep.nodeID != n.ID() { - // Server error. This is known to be a particular issue for Mullvad - // nodes (http://go/corp/27300), so log a distinct error for the - // Mullvad and non-Mullvad cases. The error will be logged either way, - // so an approximate heuristic is fine. - // - // When #27300 is fixed, we can delete this branch and log the same - // panic for any public key moving. - if strings.HasSuffix(n.Name(), ".mullvad.ts.net.") { - devPanicf("public key moved between Mullvad nodeIDs (old=%v new=%v, key=%s); see http://go/corp/27300", ep.nodeID, n.ID(), n.Key().String()) - } else { - devPanicf("public key moved between nodeIDs (old=%v new=%v, key=%s)", ep.nodeID, n.ID(), n.Key().String()) - } - } else { - // Internal data structures out of sync. - devPanicf("public key found in peerMap but not by nodeID") - } - continue - } - if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() { - // Ancient pre-0.100 node, which does not have a disco key. - // No longer supported. - continue + // discokeys might have changed above. Discard unused info. + for dk := range c.discoInfo { + if !c.peerMap.knownPeerDiscoKey(dk) { + delete(c.discoInfo, dk) } + } - ep = &endpoint{ - c: c, - nodeID: n.ID(), - publicKey: n.Key(), - publicKeyHex: n.Key().UntypedHexString(), - sentPing: map[stun.TxID]sentPing{}, - endpointState: map[netip.AddrPort]*endpointState{}, - heartbeatDisabled: flags.heartbeatDisabled, - isWireguardOnly: n.IsWireGuardOnly(), - } - switch runtime.GOOS { - case "ios", "android": - // Omit, to save memory. Prior to 2024-03-20 we used to limit it to - // ~1MB on mobile but we never used the data so the memory was just - // wasted. - default: - ep.debugUpdates = ringlog.New[EndpointChange](entriesPerBuffer) + return true +} + +// upsertPeerLocked upserts a single peer's endpoint in c.peerMap. It is the +// per-peer body shared by [Conn.SetNetworkMap]'s upsert pass and by the +// efficient per-peer [Conn.UpsertPeer] path. +// +// c.mu must be held. +func (c *Conn) upsertPeerLocked(n tailcfg.NodeView, flags debugFlags, entriesPerBuffer int) { + if n.ID() == 0 { + devPanicf("node with zero ID") + return + } + if n.Key().IsZero() { + devPanicf("node with zero key") + return + } + ep, ok := c.peerMap.endpointForNodeID(n.ID()) + if ok && ep.publicKey != n.Key() { + // The node rotated public keys. Delete the old endpoint and create + // it anew. + c.peerMap.deleteEndpoint(ep) + ok = false + } + if ok { + // At this point we're modifying an existing endpoint (ep) whose + // public key and nodeID match n. Its other fields (such as disco + // key or endpoints) might've changed. + + if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() { + // Discokey transitioned from non-zero to zero? This should not + // happen in the wild, however it could mean: + // 1. A node was downgraded from post 0.100 to pre 0.100. + // 2. A Tailscale node key was extracted and used on a + // non-Tailscale node (should not enter here due to the + // IsWireGuardOnly check) + // 3. The server is misbehaving. + c.peerMap.deleteEndpoint(ep) + return } - if n.Addresses().Len() > 0 { - ep.nodeAddr = n.Addresses().At(0).Addr() + var oldDiscoKey key.DiscoPublic + if epDisco := ep.disco.Load(); epDisco != nil { + oldDiscoKey = epDisco.key } - ep.initFakeUDPAddr() - if n.DiscoKey().IsZero() { - ep.disco.Store(nil) + ep.updateFromNode(n, flags.heartbeatDisabled, flags.probeUDPLifetimeOn) + c.peerMap.upsertEndpoint(ep, oldDiscoKey) // maybe update discokey mappings in peerMap + return + } + + if ep, ok := c.peerMap.endpointForNodeKey(n.Key()); ok { + // At this point n.Key() should be for a key we've never seen before. If + // ok was true above, it was an update to an existing matching key and + // we don't get this far. If ok was false above, that means it's a key + // that differs from the one the NodeID had. But double check. + if ep.nodeID != n.ID() { + // Server error. This is known to be a particular issue for Mullvad + // nodes (http://go/corp/27300), so log a distinct error for the + // Mullvad and non-Mullvad cases. The error will be logged either way, + // so an approximate heuristic is fine. + // + // When #27300 is fixed, we can delete this branch and log the same + // panic for any public key moving. + if strings.HasSuffix(n.Name(), ".mullvad.ts.net.") { + devPanicf("public key moved between Mullvad nodeIDs (old=%v new=%v, key=%s); see http://go/corp/27300", ep.nodeID, n.ID(), n.Key().String()) + } else { + devPanicf("public key moved between nodeIDs (old=%v new=%v, key=%s)", ep.nodeID, n.ID(), n.Key().String()) + } } else { - ep.disco.Store(&endpointDisco{ - key: n.DiscoKey(), - short: n.DiscoKey().ShortString(), - }) + // Internal data structures out of sync. + devPanicf("public key found in peerMap but not by nodeID") } + return + } + if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() { + // Ancient pre-0.100 node, which does not have a disco key. + // No longer supported. + return + } - if debugPeerMap() { - c.logEndpointCreated(n) - } + ep = &endpoint{ + c: c, + nodeID: n.ID(), + publicKey: n.Key(), + publicKeyHex: n.Key().UntypedHexString(), + sentPing: map[stun.TxID]sentPing{}, + endpointState: map[netip.AddrPort]*endpointState{}, + heartbeatDisabled: flags.heartbeatDisabled, + isWireguardOnly: n.IsWireGuardOnly(), + } + switch runtime.GOOS { + case "ios", "android": + // Omit, to save memory. Prior to 2024-03-20 we used to limit it to + // ~1MB on mobile but we never used the data so the memory was just + // wasted. + default: + ep.debugUpdates = ringlog.New[EndpointChange](entriesPerBuffer) + } + if n.Addresses().Len() > 0 { + ep.nodeAddr = n.Addresses().At(0).Addr() + } + ep.initFakeUDPAddr() + ep.updateDiscoKey(n.DiscoKey()) - ep.updateFromNode(n, flags.heartbeatDisabled, flags.probeUDPLifetimeOn) - c.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) + if debugPeerMap() { + c.logEndpointCreated(n) } - // If the set of nodes changed since the last SetNetworkMap, the - // upsert loop just above made c.peerMap contain the union of the - // old and new peers - which will be larger than the set from the - // current netmap. If that happens, go through the allocful - // deletion path to clean up moribund nodes. - if c.peerMap.nodeCount() != len(peers) { - keep := set.Set[key.NodePublic]{} - for _, n := range peers { - keep.Add(n.Key()) - } - c.peerMap.forEachEndpoint(func(ep *endpoint) { - if !keep.Contains(ep.publicKey) { - c.peerMap.deleteEndpoint(ep) - } - }) + ep.updateFromNode(n, flags.heartbeatDisabled, flags.probeUDPLifetimeOn) + c.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) +} + +// UpsertPeer adds or updates a single peer in c. It is the efficient +// O(1)-per-peer alternative to [Conn.SetNetworkMap] when a single peer was +// added or its fields changed. The caller is responsible for serializing +// UpsertPeer/RemovePeer/SetNetworkMap calls relative to one another. +// +// UpsertPeer updates the relay-server set incrementally (O(1)) when the +// upserted peer's relay candidacy changed, rather than rebuilding the +// whole set with [Conn.updateRelayServersSet]. +func (c *Conn) UpsertPeer(n tailcfg.NodeView) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return } + if n.ID() == 0 { + c.mu.Unlock() + devPanicf("UpsertPeer: node with zero ID") + return + } + flags := c.debugFlagsLocked() + c.peersByID[n.ID()] = n + c.upsertPeerLocked(n, flags, debugRingBufferSize(len(c.peersByID))) - // discokeys might have changed in the above. Discard unused info. - for dk := range c.discoInfo { - if !c.peerMap.knownPeerDiscoKey(dk) { - delete(c.discoInfo, dk) + var relayUpsert candidatePeerRelay + relayQualifies := false + if c.relayClientEnabled { + relayQualifies, relayUpsert = c.relayCandidateLocked(n) + } + relayClientEnabled := c.relayClientEnabled + c.mu.Unlock() + + if relayClientEnabled { + if relayQualifies { + c.relayManager.handleRelayServerUpsert(relayUpsert) + } else { + // The peer may have previously qualified; remove covers that + // case and is a no-op otherwise. + c.relayManager.handleRelayServerRemove(n.Key()) } } +} + +// RemovePeer removes a single peer from c. It is the efficient +// O(1)-per-peer alternative to [Conn.SetNetworkMap] when a single peer was +// removed. The caller is responsible for serializing UpsertPeer/RemovePeer/ +// SetNetworkMap calls relative to one another. +func (c *Conn) RemovePeer(nid tailcfg.NodeID) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return + } + prev, ok := c.peersByID[nid] + if !ok { + c.mu.Unlock() + return + } + delete(c.peersByID, nid) + if ep, ok := c.peerMap.endpointForNodeID(nid); ok { + c.peerMap.deleteEndpoint(ep) + } + + // If the peer we just removed held the only reference to its disco + // key, drop the now-orphaned c.discoInfo entry. No need to scan the + // whole map — only this peer's disco key can have become unreferenced + // by this single removal. + if dk := prev.DiscoKey(); !dk.IsZero() && !c.peerMap.knownPeerDiscoKey(dk) { + delete(c.discoInfo, dk) + } + + relayClientEnabled := c.relayClientEnabled + c.mu.Unlock() + + if relayClientEnabled { + // Tell the relay manager to drop the peer. The run loop no-ops + // this if the peer wasn't a relay server. + c.relayManager.handleRelayServerRemove(prev.Key()) + } +} - return peersChanged +// relayCandidateLocked reports whether peer p is eligible to be a relay +// server candidate for self, and if so returns the [candidatePeerRelay] +// that would be added to the relay-server set. c.mu must be held. +// +// It mirrors the per-peer predicate in [Conn.updateRelayServersSet]. +func (c *Conn) relayCandidateLocked(p tailcfg.NodeView) (ok bool, cp candidatePeerRelay) { + if !p.Valid() { + return false, candidatePeerRelay{} + } + // The cap-version gate in updateRelayServersSet only applies to peers + // (not self). This helper is only called for peers, so always check. + if !capVerIsRelayCapable(p.Cap()) { + return false, candidatePeerRelay{} + } + if !nodeHasCap(c.filt, p, c.self, tailcfg.PeerCapabilityRelayTarget) { + return false, candidatePeerRelay{} + } + return true, candidatePeerRelay{ + nodeKey: p.Key(), + discoKey: p.DiscoKey(), + derpHomeRegionID: uint16(p.HomeDERP()), + } } func devPanicf(format string, a ...any) { @@ -4137,16 +4255,10 @@ var _ conn.Endpoint = (*lazyEndpoint)(nil) // InitiationMessagePublicKey implements [conn.InitiationAwareEndpoint]. // wireguard-go calls us here if we passed it a [*lazyEndpoint] for an -// initiation message, for which it might not have the relevant peer configured, -// enabling us to just-in-time configure it and note its activity via -// [*endpoint.noteRecvActivity], before it performs peer lookup and attempts -// decryption. +// initiation message, for which it might not have the relevant peer configured. +// Wireguard-go's PeerLookupFunc handles on-demand peer creation. // -// Reception of all other WireGuard message types implies pre-existing knowledge -// of the peer by wireguard-go for it to do useful work. See -// [userspaceEngine.maybeReconfigWireguardLocked] & -// [userspaceEngine.noteRecvActivity] for more details around just-in-time -// wireguard-go peer (de)configuration. +// We still update endpoint activity tracking for bestAddr management. func (le *lazyEndpoint) InitiationMessagePublicKey(peerPublicKey [32]byte) { pubKey := key.NodePublicFromRaw32(mem.B(peerPublicKey[:])) if le.maybeEP != nil && pubKey.Compare(le.maybeEP.publicKey) == 0 { @@ -4154,9 +4266,6 @@ func (le *lazyEndpoint) InitiationMessagePublicKey(peerPublicKey [32]byte) { } le.c.mu.Lock() ep, ok := le.c.peerMap.endpointForNodeKey(pubKey) - // [Conn.mu] must not be held while [Conn.noteRecvActivity] is called, which - // [endpoint.noteRecvActivity] can end up calling. See - // [Options.NoteRecvActivity] docs. le.c.mu.Unlock() if !ok { return @@ -4164,11 +4273,6 @@ func (le *lazyEndpoint) InitiationMessagePublicKey(peerPublicKey [32]byte) { now := mono.Now() ep.lastRecvUDPAny.StoreAtomic(now) ep.noteRecvActivity(le.src, now) - // [ep.noteRecvActivity] may end up JIT configuring the peer, but we don't - // update [peerMap] as wireguard-go hasn't decrypted the initiation - // message yet. wireguard-go will call us below in [lazyEndpoint.FromPeer] - // if it successfully decrypts the message, at which point it's safe to - // insert le.src into the [peerMap] for ep. } func (le *lazyEndpoint) ClearSrc() {} @@ -4266,13 +4370,11 @@ func (c *Conn) HandleDiscoKeyAdvertisement(node tailcfg.NodeView, update packet. // If the key did not change, count it and return. if oldDiscoKey.Compare(discoKey) == 0 { metricTSMPDiscoKeyAdvertisementUnchanged.Add(1) + c.logf("magicsock: disco key did not change for node %v", nodeKey.ShortString()) return } c.discoInfoForKnownPeerLocked(discoKey) - ep.disco.Store(&endpointDisco{ - key: discoKey, - short: discoKey.ShortString(), - }) + ep.updateDiscoKey(discoKey) c.peerMap.upsertEndpoint(ep, oldDiscoKey) c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString()) metricTSMPDiscoKeyAdvertisementApplied.Add(1) @@ -4301,25 +4403,27 @@ type NewDiscoKeyAvailable struct { // maybeSendTSMPDiscoAdvert conditionally emits an event indicating that we // should send our DiscoKey to the first node address of the magicksock endpoint. -// The event is only emitted if we have not yet contacted that endpoint since -// the DiscoKey changed. -// -// This condition is most likely met only once per endpoint, after the start of -// tailscaled, but not until we contact the endpoint for the first time. +// The event is only emitted if we are not already communicating directly and +// more than 60 seconds has passed since the last DiscoKey was sent. // // We do not need the Conn to be locked, but the endpoint should be. func (c *Conn) maybeSendTSMPDiscoAdvert(de *endpoint) { - if !buildfeatures.HasCacheNetMap || !envknob.Bool("TS_USE_CACHED_NETMAP") { + if !buildfeatures.HasCacheNetMap || !envknob.BoolDefaultTrue("TS_USE_CACHED_NETMAP") { return } de.mu.Lock() defer de.mu.Unlock() - if !de.sentDiscoKeyAdvertisement { - de.sentDiscoKeyAdvertisement = true - c.tsmpDiscoKeyAvailablePub.Publish(NewDiscoKeyAvailable{ - NodeFirstAddr: de.nodeAddr, - NodeID: de.nodeID, - }) + + now := mono.Now() + if now.Sub(de.lastDiscoKeyAdvertisement) <= discoKeyAdvertisementInterval || + (!de.lastDiscoKeyAdvertisement.IsZero() && de.bestAddr.isDirect()) { + return } + + de.lastDiscoKeyAdvertisement = now + c.tsmpDiscoKeyAvailablePub.Publish(NewDiscoKeyAvailable{ + NodeFirstAddr: de.nodeAddr, + NodeID: de.nodeID, + }) } diff --git a/wgengine/magicsock/magicsock_linux_test.go b/wgengine/magicsock/magicsock_linux_test.go index b670fa6bab601..200979f815ab0 100644 --- a/wgengine/magicsock/magicsock_linux_test.go +++ b/wgengine/magicsock/magicsock_linux_test.go @@ -184,19 +184,19 @@ func TestBpfDiscardV4(t *testing.T) { accept bool }{ { - name: "base accepted datagram", + name: "base-accepted-datagram", replace: map[int]byte{}, accept: true, }, { - name: "more fragments", + name: "more-fragments", replace: map[int]byte{ 6: 0x20, }, accept: false, }, { - name: "some fragment", + name: "some-fragment", replace: map[int]byte{ 7: 0x01, }, diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 9d6cae87bdcc6..3552ecc191377 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -39,7 +39,6 @@ import ( "go4.org/mem" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" - "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/control/controlknobs" "tailscale.com/derp/derpserver" "tailscale.com/disco" @@ -63,8 +62,6 @@ import ( "tailscale.com/types/netlogtype" "tailscale.com/types/netmap" "tailscale.com/types/nettype" - "tailscale.com/types/ptr" - "tailscale.com/types/views" "tailscale.com/util/cibuild" "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" @@ -245,6 +242,25 @@ func newMagicStackWithKey(t testing.TB, logf logger.Logf, ln nettype.PacketListe func (s *magicStack) Reconfig(cfg *wgcfg.Config) error { s.tsTun.SetWGConfig(cfg) s.wgLogger.SetPeers(cfg.Peers) + + // In production, LocalBackend installs a PeerByIPPacketFunc via + // Engine.SetPeerByIPPacketFunc. Tests that bypass LocalBackend need + // to install one here for outbound packet routing. + ipToPeer := make(map[netip.Addr]device.NoisePublicKey, len(cfg.Peers)) + for _, p := range cfg.Peers { + pk := p.PublicKey.Raw32() + for _, pfx := range p.AllowedIPs { + if pfx.IsSingleIP() { + ipToPeer[pfx.Addr()] = pk + } + } + } + s.dev.SetPeerByIPPacketFunc(func(_, dst netip.Addr, _ []byte) (device.NoisePublicKey, bool) { + pk, ok := ipToPeer[dst] + return pk, ok + }) + + s.dev.SetPrivateKey(key.NodePrivateAs[device.NoisePrivateKey](cfg.PrivateKey)) return wgcfg.ReconfigDevice(s.dev, cfg, s.conn.logf) } @@ -415,9 +431,11 @@ func TestNewConn(t *testing.T) { stunAddr, stunCleanupFn := stuntest.Serve(t) defer stunCleanupFn() - port := pickPort(t) + // Use port 0 to let the system assign a port, avoiding TOCTOU races + // from the previous pickPort approach which would close a socket and + // hope to rebind to the same port. conn, err := NewConn(Options{ - Port: port, + Port: 0, DisablePortMapper: true, EndpointsFunc: epFunc, Logf: t.Logf, @@ -429,6 +447,13 @@ func TestNewConn(t *testing.T) { t.Fatal(err) } defer conn.Close() + + // Get the actual port that was assigned + port := conn.LocalPort() + if port == 0 { + t.Fatal("LocalPort returned 0") + } + conn.SetDERPMap(stuntest.DERPMapOf(stunAddr.String())) conn.SetPrivateKey(key.NewNode()) @@ -464,16 +489,6 @@ collectEndpoints: } } -func pickPort(t testing.TB) uint16 { - t.Helper() - conn, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer conn.Close() - return uint16(conn.LocalAddr().(*net.UDPAddr).Port) -} - func TestPickDERPFallback(t *testing.T) { tstest.PanicOnLog() tstest.ResourceCheck(t) @@ -734,7 +749,6 @@ func (localhostListener) ListenPacket(ctx context.Context, network, address stri } func TestTwoDevicePing(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/11762") ln, ip := localhostListener{}, netaddr.IPv4(127, 0, 0, 1) n := &devices{ m1: ln, @@ -1192,15 +1206,19 @@ func testTwoDevicePing(t *testing.T, d *devices) { m2.conn.SetConnectionCounter(m2.counts.Add) checkStats := func(t *testing.T, m *magicStack, wantConns []netlogtype.Connection) { + t.Helper() defer m.counts.Reset() - counts := m.counts.Clone() - for _, conn := range wantConns { - if _, ok := counts[conn]; ok { - return + if err := tstest.WaitFor(5*time.Second, func() error { + counts := m.counts.Clone() + for _, conn := range wantConns { + if _, ok := counts[conn]; ok { + return nil + } } + return fmt.Errorf("missing any connection to %s from %s", wantConns, slicesx.MapKeys(counts)) + }); err != nil { + t.Error(err) } - t.Helper() - t.Errorf("missing any connection to %s from %s", wantConns, slicesx.MapKeys(counts)) } addrPort := netip.MustParseAddrPort @@ -1214,7 +1232,7 @@ func testTwoDevicePing(t *testing.T, d *devices) { } outerT := t - t.Run("ping 1.0.0.1", func(t *testing.T) { + t.Run("ping-1_0_0_1", func(t *testing.T) { setT(t) defer setT(outerT) ping1(t) @@ -1222,7 +1240,7 @@ func testTwoDevicePing(t *testing.T, d *devices) { checkStats(t, m2, m2Conns) }) - t.Run("ping 1.0.0.2", func(t *testing.T) { + t.Run("ping-1_0_0_2", func(t *testing.T) { setT(t) defer setT(outerT) ping2(t) @@ -1230,7 +1248,7 @@ func testTwoDevicePing(t *testing.T, d *devices) { checkStats(t, m2, m2Conns) }) - t.Run("ping 1.0.0.2 via SendPacket", func(t *testing.T) { + t.Run("ping-1_0_0_2-via-SendPacket", func(t *testing.T) { setT(t) defer setT(outerT) msg1to2 := tuntest.Ping(netip.MustParseAddr("1.0.0.2"), netip.MustParseAddr("1.0.0.1")) @@ -1248,7 +1266,7 @@ func testTwoDevicePing(t *testing.T, d *devices) { checkStats(t, m2, m2Conns) }) - t.Run("no-op dev1 reconfig", func(t *testing.T) { + t.Run("no-op-dev1-reconfig", func(t *testing.T) { setT(t) defer setT(outerT) if err := m1.Reconfig(m1cfg); err != nil { @@ -1262,93 +1280,165 @@ func testTwoDevicePing(t *testing.T, d *devices) { t.Run("compare-metrics-stats", func(t *testing.T) { setT(t) defer setT(outerT) - m1.conn.resetMetricsForTest() - m1.counts.Reset() - m2.conn.resetMetricsForTest() - m2.counts.Reset() - t.Logf("Metrics before: %s\n", m1.metrics.String()) + + // Snapshot both counting systems before pings rather than + // resetting them. Resetting two independent systems + // non-atomically left a window where background WireGuard + // keepalives could increment one system but not the other, + // causing flaky off-by-one mismatches. + physBefore1, metricBefore1 := snapshotCounts(m1) + physBefore2, metricBefore2 := snapshotCounts(m2) + ping1(t) ping2(t) - assertConnStatsAndUserMetricsEqual(t, m1) - assertConnStatsAndUserMetricsEqual(t, m2) - t.Logf("Metrics after: %s\n", m1.metrics.String()) + + assertConnStatDeltasMatchMetricDeltas(t, m1, physBefore1, metricBefore1) + assertConnStatDeltasMatchMetricDeltas(t, m2, physBefore2, metricBefore2) + assertGlobalMetricsMatchPerConn(t, m1, m2) }) } -func (c *Conn) resetMetricsForTest() { - c.metrics.inboundBytesIPv4Total.Set(0) - c.metrics.inboundPacketsIPv4Total.Set(0) - c.metrics.outboundBytesIPv4Total.Set(0) - c.metrics.outboundPacketsIPv4Total.Set(0) - c.metrics.inboundBytesIPv6Total.Set(0) - c.metrics.inboundPacketsIPv6Total.Set(0) - c.metrics.outboundBytesIPv6Total.Set(0) - c.metrics.outboundPacketsIPv6Total.Set(0) - c.metrics.inboundBytesDERPTotal.Set(0) - c.metrics.inboundPacketsDERPTotal.Set(0) - c.metrics.outboundBytesDERPTotal.Set(0) - c.metrics.outboundPacketsDERPTotal.Set(0) +// countSnapshot holds a point-in-time snapshot of packet/byte statistics, +// categorized by transport type (IPv4 vs DERP). +type countSnapshot struct { + ipv4RxBytes, ipv4TxBytes int64 + ipv4RxPackets, ipv4TxPackets int64 + derpRxBytes, derpTxBytes int64 + derpRxPackets, derpTxPackets int64 } -func assertConnStatsAndUserMetricsEqual(t *testing.T, ms *magicStack) { - physIPv4RxBytes := int64(0) - physIPv4TxBytes := int64(0) - physDERPRxBytes := int64(0) - physDERPTxBytes := int64(0) - physIPv4RxPackets := int64(0) - physIPv4TxPackets := int64(0) - physDERPRxPackets := int64(0) - physDERPTxPackets := int64(0) +// snapshotCounts captures the current physical connection counter values and +// user metrics for ms, returning them as separate snapshots. Reading both +// systems back-to-back (rather than resetting them non-atomically) avoids a +// race where background WireGuard keepalives could increment one system but +// not the other during a reset window. +func snapshotCounts(ms *magicStack) (phys, metric countSnapshot) { for conn, count := range ms.counts.Clone() { - t.Logf("physconn src: %s, dst: %s", conn.Src.String(), conn.Dst.String()) if conn.Dst.String() == "127.3.3.40:1" { - physDERPRxBytes += int64(count.RxBytes) - physDERPTxBytes += int64(count.TxBytes) - physDERPRxPackets += int64(count.RxPackets) - physDERPTxPackets += int64(count.TxPackets) + phys.derpRxBytes += int64(count.RxBytes) + phys.derpTxBytes += int64(count.TxBytes) + phys.derpRxPackets += int64(count.RxPackets) + phys.derpTxPackets += int64(count.TxPackets) } else { - physIPv4RxBytes += int64(count.RxBytes) - physIPv4TxBytes += int64(count.TxBytes) - physIPv4RxPackets += int64(count.RxPackets) - physIPv4TxPackets += int64(count.TxPackets) + phys.ipv4RxBytes += int64(count.RxBytes) + phys.ipv4TxBytes += int64(count.TxBytes) + phys.ipv4RxPackets += int64(count.RxPackets) + phys.ipv4TxPackets += int64(count.TxPackets) } } - ms.counts.Reset() + metric = countSnapshot{ + ipv4RxBytes: ms.conn.metrics.inboundBytesIPv4Total.Value(), + ipv4TxBytes: ms.conn.metrics.outboundBytesIPv4Total.Value(), + ipv4RxPackets: ms.conn.metrics.inboundPacketsIPv4Total.Value(), + ipv4TxPackets: ms.conn.metrics.outboundPacketsIPv4Total.Value(), + derpRxBytes: ms.conn.metrics.inboundBytesDERPTotal.Value(), + derpTxBytes: ms.conn.metrics.outboundBytesDERPTotal.Value(), + derpRxPackets: ms.conn.metrics.inboundPacketsDERPTotal.Value(), + derpTxPackets: ms.conn.metrics.outboundPacketsDERPTotal.Value(), + } + return phys, metric +} - metricIPv4RxBytes := ms.conn.metrics.inboundBytesIPv4Total.Value() - metricIPv4RxPackets := ms.conn.metrics.inboundPacketsIPv4Total.Value() - metricIPv4TxBytes := ms.conn.metrics.outboundBytesIPv4Total.Value() - metricIPv4TxPackets := ms.conn.metrics.outboundPacketsIPv4Total.Value() +// assertConnStatDeltasMatchMetricDeltas checks that the changes in physical +// connection counters since physBefore match the changes in user metrics since +// metricBefore. Using deltas avoids a race from non-atomically resetting the +// two independent counting systems. +// +// As a safety net, a difference of exactly one packet (and the corresponding +// bytes) is tolerated, since a background WireGuard keepalive could still +// arrive in the narrow window between snapshotting the two systems. +func assertConnStatDeltasMatchMetricDeltas(t *testing.T, ms *magicStack, physBefore, metricBefore countSnapshot) { + t.Helper() + physAfter, metricAfter := snapshotCounts(ms) + + type stat struct { + name string + physDelta, metDelta int64 + isPackets bool // true for packet counts, false for byte counts + packetDeltaTolerated bool // set by packet check, used by byte check + } + + stats := []stat{ + {name: "IPv4RxPackets", physDelta: physAfter.ipv4RxPackets - physBefore.ipv4RxPackets, metDelta: metricAfter.ipv4RxPackets - metricBefore.ipv4RxPackets, isPackets: true}, + {name: "IPv4RxBytes", physDelta: physAfter.ipv4RxBytes - physBefore.ipv4RxBytes, metDelta: metricAfter.ipv4RxBytes - metricBefore.ipv4RxBytes}, + {name: "IPv4TxPackets", physDelta: physAfter.ipv4TxPackets - physBefore.ipv4TxPackets, metDelta: metricAfter.ipv4TxPackets - metricBefore.ipv4TxPackets, isPackets: true}, + {name: "IPv4TxBytes", physDelta: physAfter.ipv4TxBytes - physBefore.ipv4TxBytes, metDelta: metricAfter.ipv4TxBytes - metricBefore.ipv4TxBytes}, + {name: "DERPRxPackets", physDelta: physAfter.derpRxPackets - physBefore.derpRxPackets, metDelta: metricAfter.derpRxPackets - metricBefore.derpRxPackets, isPackets: true}, + {name: "DERPRxBytes", physDelta: physAfter.derpRxBytes - physBefore.derpRxBytes, metDelta: metricAfter.derpRxBytes - metricBefore.derpRxBytes}, + {name: "DERPTxPackets", physDelta: physAfter.derpTxPackets - physBefore.derpTxPackets, metDelta: metricAfter.derpTxPackets - metricBefore.derpTxPackets, isPackets: true}, + {name: "DERPTxBytes", physDelta: physAfter.derpTxBytes - physBefore.derpTxBytes, metDelta: metricAfter.derpTxBytes - metricBefore.derpTxBytes}, + } + + // First pass: check packet counts, tolerating Âą1 from stray keepalives. + for i := range stats { + s := &stats[i] + if !s.isPackets { + continue + } + if s.physDelta == s.metDelta { + continue + } + diff := s.physDelta - s.metDelta + if diff < 0 { + diff = -diff + } + if diff <= 1 { + s.packetDeltaTolerated = true + t.Logf("%s: physical delta=%d, metric delta=%d (off by 1, likely background WireGuard keepalive)", s.name, s.physDelta, s.metDelta) + continue + } + t.Errorf("%s: physical delta=%d, metric delta=%d", s.name, s.physDelta, s.metDelta) + } - metricDERPRxBytes := ms.conn.metrics.inboundBytesDERPTotal.Value() - metricDERPRxPackets := ms.conn.metrics.inboundPacketsDERPTotal.Value() - metricDERPTxBytes := ms.conn.metrics.outboundBytesDERPTotal.Value() - metricDERPTxPackets := ms.conn.metrics.outboundPacketsDERPTotal.Value() + // Second pass: check byte counts; tolerate mismatches when the + // corresponding packet count was already tolerated. + for i := range stats { + s := &stats[i] + if s.isPackets { + continue + } + if s.physDelta == s.metDelta { + continue + } + // The preceding entry in the slice is always the corresponding packet stat. + if stats[i-1].packetDeltaTolerated { + t.Logf("%s: physical delta=%d, metric delta=%d (within single-packet tolerance)", s.name, s.physDelta, s.metDelta) + continue + } + t.Errorf("%s: physical delta=%d, metric delta=%d", s.name, s.physDelta, s.metDelta) + } +} +// assertGlobalMetricsMatchPerConn validates that the global clientmetric +// AggregateCounters match the sum of per-conn user metrics from both magicsock +// instances. This tests the metric registration wiring rather than assuming +// symmetric traffic between the two instances. +func assertGlobalMetricsMatchPerConn(t *testing.T, m1, m2 *magicStack) { + t.Helper() c := qt.New(t) - c.Assert(physDERPRxBytes, qt.Equals, metricDERPRxBytes) - c.Assert(physDERPTxBytes, qt.Equals, metricDERPTxBytes) - c.Assert(physIPv4RxBytes, qt.Equals, metricIPv4RxBytes) - c.Assert(physIPv4TxBytes, qt.Equals, metricIPv4TxBytes) - c.Assert(physDERPRxPackets, qt.Equals, metricDERPRxPackets) - c.Assert(physDERPTxPackets, qt.Equals, metricDERPTxPackets) - c.Assert(physIPv4RxPackets, qt.Equals, metricIPv4RxPackets) - c.Assert(physIPv4TxPackets, qt.Equals, metricIPv4TxPackets) - - // Validate that the usermetrics and clientmetrics are in sync - // Note: the clientmetrics are global, this means that when they are registering with the - // wgengine, multiple in-process nodes used by this test will be updating the same metrics. This is why we need to multiply - // the metrics by 2 to get the expected value. - // TODO(kradalby): https://github.com/tailscale/tailscale/issues/13420 - c.Assert(metricSendUDP.Value(), qt.Equals, metricIPv4TxPackets*2) - c.Assert(metricSendDataPacketsIPv4.Value(), qt.Equals, metricIPv4TxPackets*2) - c.Assert(metricSendDataPacketsDERP.Value(), qt.Equals, metricDERPTxPackets*2) - c.Assert(metricSendDataBytesIPv4.Value(), qt.Equals, metricIPv4TxBytes*2) - c.Assert(metricSendDataBytesDERP.Value(), qt.Equals, metricDERPTxBytes*2) - c.Assert(metricRecvDataPacketsIPv4.Value(), qt.Equals, metricIPv4RxPackets*2) - c.Assert(metricRecvDataPacketsDERP.Value(), qt.Equals, metricDERPRxPackets*2) - c.Assert(metricRecvDataBytesIPv4.Value(), qt.Equals, metricIPv4RxBytes*2) - c.Assert(metricRecvDataBytesDERP.Value(), qt.Equals, metricDERPRxBytes*2) + m1m := m1.conn.metrics + m2m := m2.conn.metrics + + // metricSendUDP aggregates outboundPacketsIPv4Total + outboundPacketsIPv6Total + c.Assert(metricSendUDP.Value(), qt.Equals, + m1m.outboundPacketsIPv4Total.Value()+m1m.outboundPacketsIPv6Total.Value()+ + m2m.outboundPacketsIPv4Total.Value()+m2m.outboundPacketsIPv6Total.Value()) + c.Assert(metricSendDataPacketsIPv4.Value(), qt.Equals, + m1m.outboundPacketsIPv4Total.Value()+m2m.outboundPacketsIPv4Total.Value()) + c.Assert(metricSendDataPacketsDERP.Value(), qt.Equals, + m1m.outboundPacketsDERPTotal.Value()+m2m.outboundPacketsDERPTotal.Value()) + c.Assert(metricSendDataBytesIPv4.Value(), qt.Equals, + m1m.outboundBytesIPv4Total.Value()+m2m.outboundBytesIPv4Total.Value()) + c.Assert(metricSendDataBytesDERP.Value(), qt.Equals, + m1m.outboundBytesDERPTotal.Value()+m2m.outboundBytesDERPTotal.Value()) + c.Assert(metricRecvDataPacketsIPv4.Value(), qt.Equals, + m1m.inboundPacketsIPv4Total.Value()+m2m.inboundPacketsIPv4Total.Value()) + c.Assert(metricRecvDataPacketsDERP.Value(), qt.Equals, + m1m.inboundPacketsDERPTotal.Value()+m2m.inboundPacketsDERPTotal.Value()) + c.Assert(metricRecvDataBytesIPv4.Value(), qt.Equals, + m1m.inboundBytesIPv4Total.Value()+m2m.inboundBytesIPv4Total.Value()) + c.Assert(metricRecvDataBytesDERP.Value(), qt.Equals, + m1m.inboundBytesDERPTotal.Value()+m2m.inboundBytesDERPTotal.Value()) } // tests that having a endpoint.String prevents wireguard-go's @@ -1371,13 +1461,8 @@ func TestDiscoStringLogRace(t *testing.T) { } func Test32bitAlignment(t *testing.T) { - // Need an associated conn with non-nil noteRecvActivity to - // trigger interesting work on the atomics in endpoint. - called := 0 de := endpoint{ - c: &Conn{ - noteRecvActivity: func(key.NodePublic) { called++ }, - }, + c: &Conn{}, } if off := unsafe.Offsetof(de.lastRecvWG); off%8 != 0 { @@ -1385,19 +1470,12 @@ func Test32bitAlignment(t *testing.T) { } de.noteRecvActivity(epAddr{}, mono.Now()) // verify this doesn't panic on 32-bit - if called != 1 { - t.Fatal("expected call to noteRecvActivity") - } - de.noteRecvActivity(epAddr{}, mono.Now()) - if called != 1 { - t.Error("expected no second call to noteRecvActivity") - } + de.noteRecvActivity(epAddr{}, mono.Now()) // second call should be throttled } // newTestConn returns a new Conn. func newTestConn(t testing.TB) *Conn { t.Helper() - port := pickPort(t) bus := eventbustest.NewBus(t) @@ -1414,7 +1492,7 @@ func newTestConn(t testing.TB) *Conn { Metrics: new(usermetric.Registry), DisablePortMapper: true, Logf: t.Logf, - Port: port, + Port: 0, TestOnlyPacketListener: localhostListener{}, EndpointsFunc: func(eps []tailcfg.Endpoint) { t.Logf("endpoints: %q", eps) @@ -2008,7 +2086,7 @@ func TestStressSetNetworkMap(t *testing.T) { const iters = 1000 // approx 0.5s on an m1 mac for range iters { - for j := 0; j < npeers; j++ { + for j := range npeers { // Randomize which peers are present. if prng.Int()&1 == 0 { present[j] = !present[j] @@ -2197,7 +2275,7 @@ func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Devic if err != nil { t.Fatal(err) } - for _, line := range strings.Split(s, "\n") { + for line := range strings.SplitSeq(s, "\n") { line = strings.TrimSpace(line) if len(line) == 0 { continue @@ -2309,7 +2387,7 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) { IsWireGuardOnly: true, Addresses: []netip.Prefix{wgaip}, AllowedIPs: []netip.Prefix{wgaip}, - SelfNodeV4MasqAddrForThisPeer: ptr.To(masqip.Addr()), + SelfNodeV4MasqAddrForThisPeer: new(masqip.Addr()), }, }), } @@ -2693,7 +2771,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { want epAddr }{ { - name: "no endpoints", + name: "no-endpoints", sendInitialPing: false, validAddr: false, sendFollowUpPing: false, @@ -2702,7 +2780,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { want: epAddr{}, }, { - name: "singular endpoint does not request ping", + name: "singular-endpoint-no-ping-request", sendInitialPing: false, validAddr: true, sendFollowUpPing: false, @@ -2716,7 +2794,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { want: epAddr{ap: netip.MustParseAddrPort("1.1.1.1:111")}, }, { - name: "ping sent within wireguardPingInterval should not request ping", + name: "ping-within-wireguardPingInterval-no-request", sendInitialPing: true, validAddr: true, sendFollowUpPing: false, @@ -2734,7 +2812,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { want: epAddr{ap: netip.MustParseAddrPort("1.1.1.1:111")}, }, { - name: "ping sent outside of wireguardPingInterval should request ping", + name: "ping-outside-wireguardPingInterval-requests-ping", sendInitialPing: true, validAddr: true, sendFollowUpPing: true, @@ -2752,7 +2830,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { want: epAddr{ap: netip.MustParseAddrPort("1.1.1.1:111")}, }, { - name: "choose lowest latency for useable IPv4 and IPv6", + name: "choose-lowest-latency-v4-and-v6", sendInitialPing: true, validAddr: true, sendFollowUpPing: false, @@ -2770,7 +2848,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { want: epAddr{ap: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222")}, }, { - name: "choose IPv6 address when latency is the same for v4 and v6", + name: "choose-IPv6-when-equal-latency", sendInitialPing: true, validAddr: true, sendFollowUpPing: false, @@ -3010,6 +3088,7 @@ func TestMaybeSetNearestDERP(t *testing.T) { old int reportDERP int connectedToControl bool + force bool want int }{ { @@ -3033,6 +3112,22 @@ func TestMaybeSetNearestDERP(t *testing.T) { connectedToControl: false, // not connected... want: 21, // ... but want to change to new DERP }, + { + name: "force_not_connected_with_report_derp", + old: 1, + reportDERP: 21, + connectedToControl: false, + force: true, + want: 21, // force bypasses the no-change-without-control guard + }, + { + name: "force_not_connected_no_derp_no_current", + old: 0, + reportDERP: 0, + connectedToControl: false, + force: true, + want: 31, // force + no report DERP → deterministic fallback + }, { name: "not_connected_with_fallback_and_no_current", old: 0, // no current DERP @@ -3057,8 +3152,13 @@ func TestMaybeSetNearestDERP(t *testing.T) { } for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { - ht := health.NewTracker(eventbustest.NewBus(t)) + bus := eventbustest.NewBus(t) + ht := health.NewTracker(bus) c := newConn(t.Logf) + ec := bus.Client("magicsock.Conn.Test") + c.eventClient = ec + c.homeDERPChangedPub = eventbus.Publish[HomeDERPChanged](ec) + c.eventBus = bus c.myDerp = tt.old c.derpMap = derpMap c.health = ht @@ -3076,7 +3176,7 @@ func TestMaybeSetNearestDERP(t *testing.T) { } } - got := c.maybeSetNearestDERP(report) + got := c.maybeSetNearestDERP(report, tt.force) if got != tt.want { t.Errorf("got new DERP region %d, want %d", got, tt.want) } @@ -3203,6 +3303,8 @@ func TestNetworkSendErrors(t *testing.T) { t.Skipf("skipping on %s", runtime.GOOS) } + tstest.Replace(t, &checkNetworkDownDuringTests, true) + conn, reg, close := newTestConnAndRegistry(t) defer close() @@ -3338,73 +3440,73 @@ func Test_packetLooksLike(t *testing.T) { wantIsGeneveEncap bool }{ { - name: "STUN binding success response", + name: "STUN-binding-success-response", msg: stun.Response(stun.NewTxID(), netip.MustParseAddrPort("127.0.0.1:1")), wantPacketLooksLikeType: packetLooksLikeSTUNBinding, wantIsGeneveEncap: false, }, { - name: "naked disco", + name: "naked-disco", msg: nakedDisco, wantPacketLooksLikeType: packetLooksLikeDisco, wantIsGeneveEncap: false, }, { - name: "geneve encap disco", + name: "geneve-encap-disco", msg: geneveEncapDisco, wantPacketLooksLikeType: packetLooksLikeDisco, wantIsGeneveEncap: true, }, { - name: "geneve encap too short disco", + name: "geneve-encap-too-short-disco", msg: geneveEncapDisco[:len(geneveEncapDisco)-key.DiscoPublicRawLen], wantPacketLooksLikeType: packetLooksLikeWireGuard, wantIsGeneveEncap: false, }, { - name: "geneve encap disco nonzero geneve version", + name: "geneve-encap-disco-nonzero-geneve-version", msg: geneveEncapDiscoNonZeroGeneveVersion, wantPacketLooksLikeType: packetLooksLikeWireGuard, wantIsGeneveEncap: false, }, { - name: "geneve encap disco nonzero geneve reserved bits", + name: "geneve-encap-disco-nonzero-geneve-reserved-bits", msg: geneveEncapDiscoNonZeroGeneveReservedBits, wantPacketLooksLikeType: packetLooksLikeWireGuard, wantIsGeneveEncap: false, }, { - name: "geneve encap disco nonzero geneve vni lsb", + name: "geneve-encap-disco-nonzero-geneve-vni-lsb", msg: geneveEncapDiscoNonZeroGeneveVNILSB, wantPacketLooksLikeType: packetLooksLikeWireGuard, wantIsGeneveEncap: false, }, { - name: "geneve encap wireguard", + name: "geneve-encap-wireguard", msg: geneveEncapWireGuard, wantPacketLooksLikeType: packetLooksLikeWireGuard, wantIsGeneveEncap: true, }, { - name: "naked WireGuard Initiation type", + name: "naked-WireGuard-Initiation-type", msg: nakedWireGuardInitiation, wantPacketLooksLikeType: packetLooksLikeWireGuard, wantIsGeneveEncap: false, }, { - name: "naked WireGuard Response type", + name: "naked-WireGuard-Response-type", msg: nakedWireGuardResponse, wantPacketLooksLikeType: packetLooksLikeWireGuard, wantIsGeneveEncap: false, }, { - name: "naked WireGuard Cookie Reply type", + name: "naked-WireGuard-Cookie-Reply-type", msg: nakedWireGuardCookieReply, wantPacketLooksLikeType: packetLooksLikeWireGuard, wantIsGeneveEncap: false, }, { - name: "naked WireGuard Transport type", + name: "naked-WireGuard-Transport-type", msg: nakedWireGuardTransport, wantPacketLooksLikeType: packetLooksLikeWireGuard, wantIsGeneveEncap: false, @@ -3441,22 +3543,22 @@ func Test_looksLikeInitiationMsg(t *testing.T) { want bool }{ { - name: "valid initiation", + name: "valid-initiation", b: initMsg, want: true, }, { - name: "invalid message type field", + name: "invalid-message-type-field", b: initMsgSizeTransportType, want: false, }, { - name: "too small", + name: "too-small", b: initMsg[:device.MessageInitiationSize-1], want: false, }, { - name: "too big", + name: "too-big", b: append(initMsg, 0), want: false, }, @@ -3498,7 +3600,7 @@ func Test_nodeHasCap(t *testing.T) { want bool }{ { - name: "match v4", + name: "match-v4", filt: filter.New([]filtertype.Match{ { Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, @@ -3516,7 +3618,7 @@ func Test_nodeHasCap(t *testing.T) { want: true, }, { - name: "match v6", + name: "match-v6", filt: filter.New([]filtertype.Match{ { Srcs: []netip.Prefix{netip.MustParsePrefix("::2/128")}, @@ -3534,7 +3636,7 @@ func Test_nodeHasCap(t *testing.T) { want: true, }, { - name: "no match CapMatch Dst", + name: "no-match-CapMatch-Dst", filt: filter.New([]filtertype.Match{ { Srcs: []netip.Prefix{netip.MustParsePrefix("::2/128")}, @@ -3552,7 +3654,7 @@ func Test_nodeHasCap(t *testing.T) { want: false, }, { - name: "no match peer cap", + name: "no-match-peer-cap", filt: filter.New([]filtertype.Match{ { Srcs: []netip.Prefix{netip.MustParsePrefix("::2/128")}, @@ -3570,7 +3672,7 @@ func Test_nodeHasCap(t *testing.T) { want: false, }, { - name: "nil src", + name: "nil-src", filt: filter.New([]filtertype.Match{ { Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, @@ -3588,7 +3690,7 @@ func Test_nodeHasCap(t *testing.T) { want: false, }, { - name: "nil dst", + name: "nil-dst", filt: filter.New([]filtertype.Match{ { Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, @@ -3666,7 +3768,7 @@ func TestConn_SetNetworkMap_updateRelayServersSet(t *testing.T) { wantRelayClientEnabled bool }{ { - name: "candidate relay server", + name: "candidate-relay-server", filt: filter.New([]filtertype.Match{ { Srcs: peerNodeCandidateRelay.Addresses, @@ -3690,7 +3792,7 @@ func TestConn_SetNetworkMap_updateRelayServersSet(t *testing.T) { wantRelayClientEnabled: true, }, { - name: "no candidate relay server because self has tailcfg.NodeAttrDisableRelayClient", + name: "no-candidate-self-has-DisableRelayClient", // self has tailcfg.NodeAttrDisableRelayClient filt: filter.New([]filtertype.Match{ { Srcs: peerNodeCandidateRelay.Addresses, @@ -3708,7 +3810,7 @@ func TestConn_SetNetworkMap_updateRelayServersSet(t *testing.T) { wantRelayClientEnabled: false, }, { - name: "no candidate relay server because self has tailcfg.NodeAttrOnlyTCP443", + name: "no-candidate-self-has-OnlyTCP443", // self has tailcfg.NodeAttrOnlyTCP443 filt: filter.New([]filtertype.Match{ { Srcs: peerNodeCandidateRelay.Addresses, @@ -3726,7 +3828,7 @@ func TestConn_SetNetworkMap_updateRelayServersSet(t *testing.T) { wantRelayClientEnabled: false, }, { - name: "self candidate relay server", + name: "self-candidate-relay-server", filt: filter.New([]filtertype.Match{ { Srcs: selfNode.Addresses, @@ -3750,7 +3852,7 @@ func TestConn_SetNetworkMap_updateRelayServersSet(t *testing.T) { wantRelayClientEnabled: true, }, { - name: "no candidate relay server", + name: "no-candidate-relay-server", filt: filter.New([]filtertype.Match{ { Srcs: peerNodeNotCandidateRelayCapVer.Addresses, @@ -3774,7 +3876,7 @@ func TestConn_SetNetworkMap_updateRelayServersSet(t *testing.T) { c.filt = tt.filt if len(tt.wantRelayServers) == 0 { // So we can verify it gets flipped back. - c.hasPeerRelayServers.Store(true) + c.relayManager.hasPeerRelayServers.Store(true) } c.SetNetworkMap(tt.self, tt.peers) @@ -3782,8 +3884,8 @@ func TestConn_SetNetworkMap_updateRelayServersSet(t *testing.T) { if !got.Equal(tt.wantRelayServers) { t.Fatalf("got: %v != want: %v", got, tt.wantRelayServers) } - if len(tt.wantRelayServers) > 0 != c.hasPeerRelayServers.Load() { - t.Fatalf("c.hasPeerRelayServers: %v != len(tt.wantRelayServers) > 0: %v", c.hasPeerRelayServers.Load(), len(tt.wantRelayServers) > 0) + if got, want := c.relayManager.hasPeerRelayServers.Load(), len(tt.wantRelayServers) > 0; got != want { + t.Fatalf("c.relayManager.hasPeerRelayServers: %v != len(tt.wantRelayServers) > 0: %v", got, want) } if c.relayClientEnabled != tt.wantRelayClientEnabled { t.Fatalf("c.relayClientEnabled: %v != wantRelayClientEnabled: %v", c.relayClientEnabled, tt.wantRelayClientEnabled) @@ -3863,63 +3965,58 @@ func TestConn_receiveIP(t *testing.T) { // If [*endpoint] then we expect 'got' to be the same [*endpoint]. If // [*lazyEndpoint] and [*lazyEndpoint.maybeEP] is non-nil, we expect // got.maybeEP to also be non-nil. Must not be reused across tests. - wantEndpointType wgconn.Endpoint - wantSize int - wantIsGeneveEncap bool - wantOk bool - wantMetricInc *clientmetric.Metric - wantNoteRecvActivityCalled bool + wantEndpointType wgconn.Endpoint + wantSize int + wantIsGeneveEncap bool + wantOk bool + wantMetricInc *clientmetric.Metric }{ { - name: "naked disco", - b: looksLikeNakedDisco, - ipp: netip.MustParseAddrPort("127.0.0.1:7777"), - cache: &epAddrEndpointCache{}, - wantEndpointType: nil, - wantSize: 0, - wantIsGeneveEncap: false, - wantOk: false, - wantMetricInc: metricRecvDiscoBadPeer, - wantNoteRecvActivityCalled: false, + name: "naked-disco", + b: looksLikeNakedDisco, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: nil, + wantSize: 0, + wantIsGeneveEncap: false, + wantOk: false, + wantMetricInc: metricRecvDiscoBadPeer, }, { - name: "geneve encap disco", - b: looksLikeGeneveDisco, - ipp: netip.MustParseAddrPort("127.0.0.1:7777"), - cache: &epAddrEndpointCache{}, - wantEndpointType: nil, - wantSize: 0, - wantIsGeneveEncap: false, - wantOk: false, - wantMetricInc: metricRecvDiscoBadPeer, - wantNoteRecvActivityCalled: false, + name: "geneve-encap-disco", + b: looksLikeGeneveDisco, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: nil, + wantSize: 0, + wantIsGeneveEncap: false, + wantOk: false, + wantMetricInc: metricRecvDiscoBadPeer, }, { - name: "STUN binding", - b: looksLikeSTUNBinding, - ipp: netip.MustParseAddrPort("127.0.0.1:7777"), - cache: &epAddrEndpointCache{}, - wantEndpointType: nil, - wantSize: 0, - wantIsGeneveEncap: false, - wantOk: false, - wantMetricInc: findMetricByName("netcheck_stun_recv_ipv4"), - wantNoteRecvActivityCalled: false, + name: "STUN-binding", + b: looksLikeSTUNBinding, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: nil, + wantSize: 0, + wantIsGeneveEncap: false, + wantOk: false, + wantMetricInc: findMetricByName("netcheck_stun_recv_ipv4"), }, { - name: "naked WireGuard init lazyEndpoint empty peerMap", - b: looksLikeNakedWireGuardInit, - ipp: netip.MustParseAddrPort("127.0.0.1:7777"), - cache: &epAddrEndpointCache{}, - wantEndpointType: &lazyEndpoint{}, - wantSize: len(looksLikeNakedWireGuardInit), - wantIsGeneveEncap: false, - wantOk: true, - wantMetricInc: nil, - wantNoteRecvActivityCalled: false, + name: "naked-WireGuard-init-lazyEndpoint-empty-peerMap", + b: looksLikeNakedWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: &lazyEndpoint{}, + wantSize: len(looksLikeNakedWireGuardInit), + wantIsGeneveEncap: false, + wantOk: true, + wantMetricInc: nil, }, { - name: "naked WireGuard init endpoint matching peerMap entry", + name: "naked-WireGuard-init-endpoint-matching-peerMap-entry", b: looksLikeNakedWireGuardInit, ipp: netip.MustParseAddrPort("127.0.0.1:7777"), cache: &epAddrEndpointCache{}, @@ -3930,22 +4027,20 @@ func TestConn_receiveIP(t *testing.T) { wantIsGeneveEncap: false, wantOk: true, wantMetricInc: nil, - wantNoteRecvActivityCalled: true, }, { - name: "geneve WireGuard init lazyEndpoint empty peerMap", - b: looksLikeGeneveWireGuardInit, - ipp: netip.MustParseAddrPort("127.0.0.1:7777"), - cache: &epAddrEndpointCache{}, - wantEndpointType: &lazyEndpoint{}, - wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, - wantIsGeneveEncap: true, - wantOk: true, - wantMetricInc: nil, - wantNoteRecvActivityCalled: false, + name: "geneve-WireGuard-init-lazyEndpoint-empty-peerMap", + b: looksLikeGeneveWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: &lazyEndpoint{}, + wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, + wantIsGeneveEncap: true, + wantOk: true, + wantMetricInc: nil, }, { - name: "geneve WireGuard init lazyEndpoint matching peerMap activity noted", + name: "geneve-WireGuard-init-lazyEndpoint-matching-peerMap-activity-noted", b: looksLikeGeneveWireGuardInit, ipp: netip.MustParseAddrPort("127.0.0.1:7777"), cache: &epAddrEndpointCache{}, @@ -3954,14 +4049,13 @@ func TestConn_receiveIP(t *testing.T) { wantEndpointType: &lazyEndpoint{ maybeEP: newPeerMapInsertableEndpoint(0), }, - wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, - wantIsGeneveEncap: true, - wantOk: true, - wantMetricInc: nil, - wantNoteRecvActivityCalled: true, + wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, + wantIsGeneveEncap: true, + wantOk: true, + wantMetricInc: nil, }, { - name: "geneve WireGuard init lazyEndpoint matching peerMap no activity noted", + name: "geneve-WireGuard-init-lazyEndpoint-matching-peerMap-no-activity-noted", b: looksLikeGeneveWireGuardInit, ipp: netip.MustParseAddrPort("127.0.0.1:7777"), cache: &epAddrEndpointCache{}, @@ -3970,17 +4064,15 @@ func TestConn_receiveIP(t *testing.T) { wantEndpointType: &lazyEndpoint{ maybeEP: newPeerMapInsertableEndpoint(mono.Now().Add(time.Hour * 24)), }, - wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, - wantIsGeneveEncap: true, - wantOk: true, - wantMetricInc: nil, - wantNoteRecvActivityCalled: false, + wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, + wantIsGeneveEncap: true, + wantOk: true, + wantMetricInc: nil, }, // TODO(jwhited): verify cache.de is used when conditions permit } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - noteRecvActivityCalled := false metricBefore := int64(0) if tt.wantMetricInc != nil { metricBefore = tt.wantMetricInc.Value() @@ -3993,9 +4085,6 @@ func TestConn_receiveIP(t *testing.T) { peerMap: newPeerMap(), } c.havePrivateKey.Store(true) - c.noteRecvActivity = func(public key.NodePublic) { - noteRecvActivityCalled = true - } var counts netlogtype.CountsByConnection c.SetConnectionCounter(counts.Add) @@ -4050,10 +4139,6 @@ func TestConn_receiveIP(t *testing.T) { if tt.wantMetricInc != nil && tt.wantMetricInc.Value() != metricBefore+1 { t.Errorf("receiveIP() metric %v not incremented", tt.wantMetricInc.Name()) } - if tt.wantNoteRecvActivityCalled != noteRecvActivityCalled { - t.Errorf("receiveIP() noteRecvActivityCalled = %v, want %v", noteRecvActivityCalled, tt.wantNoteRecvActivityCalled) - } - if tt.cache.de != nil { switch ep := got.(type) { case *endpoint: @@ -4105,34 +4190,29 @@ func TestConn_receiveIP(t *testing.T) { func Test_lazyEndpoint_InitiationMessagePublicKey(t *testing.T) { tests := []struct { - name string - callWithPeerMapKey bool - maybeEPMatchingKey bool - wantNoteRecvActivityCalled bool + name string + callWithPeerMapKey bool + maybeEPMatchingKey bool }{ { - name: "noteRecvActivity called", - callWithPeerMapKey: true, - maybeEPMatchingKey: false, - wantNoteRecvActivityCalled: true, + name: "noteRecvActivity-called", + callWithPeerMapKey: true, + maybeEPMatchingKey: false, }, { - name: "maybeEP early return", - callWithPeerMapKey: true, - maybeEPMatchingKey: true, - wantNoteRecvActivityCalled: false, + name: "maybeEP-early-return", + callWithPeerMapKey: true, + maybeEPMatchingKey: true, }, { - name: "not in peerMap early return", - callWithPeerMapKey: false, - maybeEPMatchingKey: false, - wantNoteRecvActivityCalled: false, + name: "not-in-peerMap-early-return", + callWithPeerMapKey: false, + maybeEPMatchingKey: false, }, { - name: "not in peerMap maybeEP early return", - callWithPeerMapKey: false, - maybeEPMatchingKey: true, - wantNoteRecvActivityCalled: false, + name: "not-in-peerMap-maybeEP-early-return", + callWithPeerMapKey: false, + maybeEPMatchingKey: true, }, } for _, tt := range tests { @@ -4145,19 +4225,7 @@ func Test_lazyEndpoint_InitiationMessagePublicKey(t *testing.T) { key: key.NewDisco().Public(), }) - var noteRecvActivityCalledFor key.NodePublic conn := newConn(t.Logf) - conn.noteRecvActivity = func(public key.NodePublic) { - // wireguard-go will call into ParseEndpoint if the "real" - // noteRecvActivity ends up JIT configuring the peer. Mimic that - // to ensure there are no deadlocks around conn.mu. - // See tailscale/tailscale#16651 & http://go/corp#30836 - _, err := conn.ParseEndpoint(ep.publicKey.UntypedHexString()) - if err != nil { - t.Fatalf("ParseEndpoint() err: %v", err) - } - noteRecvActivityCalledFor = public - } ep.c = conn var pubKey [32]byte @@ -4173,13 +4241,6 @@ func Test_lazyEndpoint_InitiationMessagePublicKey(t *testing.T) { le.maybeEP = ep } le.InitiationMessagePublicKey(pubKey) - want := key.NodePublic{} - if tt.wantNoteRecvActivityCalled { - want = ep.publicKey - } - if noteRecvActivityCalledFor.Compare(want) != 0 { - t.Fatalf("noteRecvActivityCalledFor = %v, want %v", noteRecvActivityCalledFor, want) - } }) } } @@ -4192,25 +4253,25 @@ func Test_lazyEndpoint_FromPeer(t *testing.T) { wantEpAddrInPeerMap bool }{ { - name: "epAddr in peerMap", + name: "epAddr-in-peerMap", callWithPeerMapKey: true, maybeEPMatchingKey: false, wantEpAddrInPeerMap: true, }, { - name: "maybeEP early return", + name: "maybeEP-early-return", callWithPeerMapKey: true, maybeEPMatchingKey: true, wantEpAddrInPeerMap: false, }, { - name: "not in peerMap early return", + name: "not-in-peerMap-early-return", callWithPeerMapKey: false, maybeEPMatchingKey: false, wantEpAddrInPeerMap: false, }, { - name: "not in peerMap maybeEP early return", + name: "not-in-peerMap-maybeEP-early-return", callWithPeerMapKey: false, maybeEPMatchingKey: true, wantEpAddrInPeerMap: false, @@ -4312,7 +4373,7 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) { keys := make([]key.DiscoPublic, 0, 5) keys = append(keys, c.discoAtomic.Public()) - for i := 0; i < 4; i++ { + for i := range 4 { c.RotateDiscoKey() newKey := c.discoAtomic.Public() @@ -4349,7 +4410,7 @@ func TestReceiveTSMPDiscoKeyAdvertisement(t *testing.T) { netip.MustParsePrefix("100.64.0.1/32"), }, }).View() - conn.peers = views.SliceOf([]tailcfg.NodeView{nodeView}) + conn.peersByID = map[tailcfg.NodeID]tailcfg.NodeView{nodeView.ID(): nodeView} conn.mu.Unlock() conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) @@ -4369,3 +4430,70 @@ func TestReceiveTSMPDiscoKeyAdvertisement(t *testing.T) { t.Errorf("New disco key %s, does not match %s", newDiscoKey.ShortString(), ep.disco.Load().short) } } + +func TestSendingTSMPDiscoTimer(t *testing.T) { + conn := newTestConn(t) + tw := eventbustest.NewWatcher(t, conn.eventBus) + t.Cleanup(func() { conn.Close() }) + + peerKey := key.NewNode().Public() + ep := &endpoint{ + nodeID: 1, + publicKey: peerKey, + nodeAddr: netip.MustParseAddr("100.64.0.1"), + } + discoKey := key.NewDisco().Public() + ep.disco.Store(&endpointDisco{ + key: discoKey, + short: discoKey.ShortString(), + }) + ep.c = conn + conn.mu.Lock() + nodeView := (&tailcfg.Node{ + Key: ep.publicKey, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + }, + }).View() + conn.peersByID = map[tailcfg.NodeID]tailcfg.NodeView{nodeView.ID(): nodeView} + conn.mu.Unlock() + + conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) + + if ep.discoShort() != discoKey.ShortString() { + t.Errorf("Original disco key %s, does not match %s", discoKey.ShortString(), ep.discoShort()) + } + + // Only one gets through, second is rate limited. + conn.maybeSendTSMPDiscoAdvert(ep) + conn.maybeSendTSMPDiscoAdvert(ep) + if err := eventbustest.ExpectExactly(tw, eventbustest.Type[NewDiscoKeyAvailable]()); err != nil { + t.Errorf("expected only one event, got: %s", err) + } + + // Reset to get the event firing again. + ep.mu.Lock() + ep.lastDiscoKeyAdvertisement = 0 + ep.mu.Unlock() + conn.maybeSendTSMPDiscoAdvert(ep) + if err := eventbustest.Expect(tw, eventbustest.Type[NewDiscoKeyAvailable]()); err != nil { + t.Errorf("expected only one event, got: %s", err) + } + + // With a direct bestAddr and a non-zero lastDiscoKeyAdvertisement past the + // rate-limit interval. No advert should be sent due to the active bestAddr. + ep.mu.Lock() + ep.lastDiscoKeyAdvertisement = mono.Now().Add(-discoKeyAdvertisementInterval - time.Second) + ep.bestAddr = addrQuality{epAddr: epAddr{ap: netip.MustParseAddrPort("1.2.3.4:567")}} + ep.mu.Unlock() + conn.maybeSendTSMPDiscoAdvert(ep) + + // Simulating restart should send an advert. + ep.mu.Lock() + ep.lastDiscoKeyAdvertisement = 0 + ep.mu.Unlock() + conn.maybeSendTSMPDiscoAdvert(ep) + if err := eventbustest.ExpectExactly(tw, eventbustest.Type[NewDiscoKeyAvailable]()); err != nil { + t.Errorf("expected only one event, got: %s", err) + } +} diff --git a/wgengine/magicsock/rebinding_conn.go b/wgengine/magicsock/rebinding_conn.go index e00eed1f5c88c..11398c5925617 100644 --- a/wgengine/magicsock/rebinding_conn.go +++ b/wgengine/magicsock/rebinding_conn.go @@ -43,7 +43,7 @@ type RebindingUDPConn struct { // disrupting surrounding code that assumes nettype.PacketConn is a // *net.UDPConn. func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) { - upc := batching.TryUpgradeToConn(p, network, batchSize) + upc := batching.TryUpgradeToConn(p, network, batchSize, "magicsock_udp_rxq_overflows") c.pconn = upc c.pconnAtomic.Store(&upc) c.port = uint16(c.localAddrLocked().Port) diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go index e4cd5eb9ff537..8ea15bce30199 100644 --- a/wgengine/magicsock/relaymanager.go +++ b/wgengine/magicsock/relaymanager.go @@ -9,6 +9,7 @@ import ( "fmt" "net/netip" "sync" + "sync/atomic" "time" "tailscale.com/disco" @@ -34,6 +35,14 @@ import ( type relayManager struct { initOnce sync.Once + // hasPeerRelayServers is whether relayManager is configured with at + // least one peer relay server via [relayManager.handleRelayServersSet] + // (or per-peer variants). Exposed as an atomic so [endpoint] hot paths + // can short-circuit when there are no relay servers without taking any + // lock or entering the run loop. Written only from runLoop() via + // [relayManager.publishHasServersRunLoop]. + hasPeerRelayServers atomic.Bool + // =================================================================== // The following fields are owned by a single goroutine, runLoop(). serversByNodeKey map[key.NodePublic]candidatePeerRelay @@ -56,6 +65,8 @@ type relayManager struct { newServerEndpointCh chan newRelayServerEndpointEvent rxDiscoMsgCh chan relayDiscoMsgEvent serversCh chan set.Set[candidatePeerRelay] + serverUpsertCh chan candidatePeerRelay + serverRemoveCh chan key.NodePublic getServersCh chan chan set.Set[candidatePeerRelay] derpHomeChangeCh chan derpHomeChangeEvent @@ -228,6 +239,16 @@ func (r *relayManager) runLoop() { if !r.hasActiveWorkRunLoop() { return } + case upsert := <-r.serverUpsertCh: + r.handleServerUpsertRunLoop(upsert) + if !r.hasActiveWorkRunLoop() { + return + } + case nk := <-r.serverRemoveCh: + r.handleServerRemoveRunLoop(nk) + if !r.hasActiveWorkRunLoop() { + return + } case getServersCh := <-r.getServersCh: r.handleGetServersRunLoop(getServersCh) if !r.hasActiveWorkRunLoop() { @@ -265,6 +286,34 @@ func (r *relayManager) handleServersUpdateRunLoop(update set.Set[candidatePeerRe for _, v := range update.Slice() { r.serversByNodeKey[v.nodeKey] = v } + r.publishHasServersRunLoop() +} + +// handleServerUpsertRunLoop inserts or updates cp in serversByNodeKey. It is +// the per-peer analog of [relayManager.handleServersUpdateRunLoop] used by +// [Conn.UpsertPeer]. +func (r *relayManager) handleServerUpsertRunLoop(cp candidatePeerRelay) { + r.serversByNodeKey[cp.nodeKey] = cp + r.publishHasServersRunLoop() +} + +// handleServerRemoveRunLoop deletes nk from serversByNodeKey. It is a no-op +// if nk isn't currently a known server. It is the per-peer analog of +// [relayManager.handleServersUpdateRunLoop] used by [Conn.RemovePeer] and by +// [Conn.UpsertPeer] when a peer is upserted with fields that make it no +// longer a relay candidate. +func (r *relayManager) handleServerRemoveRunLoop(nk key.NodePublic) { + if _, ok := r.serversByNodeKey[nk]; !ok { + return + } + delete(r.serversByNodeKey, nk) + r.publishHasServersRunLoop() +} + +// publishHasServersRunLoop updates [relayManager.hasPeerRelayServers] to +// reflect whether any relay servers are currently known. +func (r *relayManager) publishHasServersRunLoop() { + r.hasPeerRelayServers.Store(len(r.serversByNodeKey) > 0) } type relayDiscoMsgEvent struct { @@ -330,6 +379,8 @@ func (r *relayManager) init() { r.newServerEndpointCh = make(chan newRelayServerEndpointEvent) r.rxDiscoMsgCh = make(chan relayDiscoMsgEvent) r.serversCh = make(chan set.Set[candidatePeerRelay]) + r.serverUpsertCh = make(chan candidatePeerRelay) + r.serverRemoveCh = make(chan key.NodePublic) r.getServersCh = make(chan chan set.Set[candidatePeerRelay]) r.derpHomeChangeCh = make(chan derpHomeChangeEvent) r.runLoopStoppedCh = make(chan struct{}, 1) @@ -436,6 +487,21 @@ func (r *relayManager) handleRelayServersSet(servers set.Set[candidatePeerRelay] relayManagerInputEvent(r, nil, &r.serversCh, servers) } +// handleRelayServerUpsert is the O(1) per-peer variant of +// [relayManager.handleRelayServersSet]: it inserts or updates a single +// relay server entry. +func (r *relayManager) handleRelayServerUpsert(cp candidatePeerRelay) { + relayManagerInputEvent(r, nil, &r.serverUpsertCh, cp) +} + +// handleRelayServerRemove is the O(1) per-peer variant of +// [relayManager.handleRelayServersSet]: it removes a single relay server +// entry by node key. It is a no-op if nk is not currently a known relay +// server. +func (r *relayManager) handleRelayServerRemove(nk key.NodePublic) { + relayManagerInputEvent(r, nil, &r.serverRemoveCh, nk) +} + // relayManagerInputEvent initializes [relayManager] if necessary, starts // relayManager.runLoop() if it is not running, and writes 'event' on 'eventCh'. // diff --git a/wgengine/magicsock/relaymanager_test.go b/wgengine/magicsock/relaymanager_test.go index 7d773e381a7c4..47d935404b5fe 100644 --- a/wgengine/magicsock/relaymanager_test.go +++ b/wgengine/magicsock/relaymanager_test.go @@ -141,7 +141,7 @@ func TestRelayManager_handleNewServerEndpointRunLoop(t *testing.T) { }{ { // Test for http://go/corp/32978 - name: "eq server+ep neq VNI higher lamport", + name: "eq-server-ep-neq-VNI-higher-lamport", events: []newRelayServerEndpointEvent{ serverAendpointALamport1VNI1, serverAendpointALamport2VNI2, @@ -151,7 +151,7 @@ func TestRelayManager_handleNewServerEndpointRunLoop(t *testing.T) { }, }, { - name: "eq server+ep neq VNI lower lamport", + name: "eq-server-ep-neq-VNI-lower-lamport", events: []newRelayServerEndpointEvent{ serverAendpointALamport2VNI2, serverAendpointALamport1VNI1, @@ -161,7 +161,7 @@ func TestRelayManager_handleNewServerEndpointRunLoop(t *testing.T) { }, }, { - name: "eq server+vni neq ep lower lamport", + name: "eq-server-vni-neq-ep-lower-lamport", events: []newRelayServerEndpointEvent{ serverAendpointALamport2VNI2, serverAendpointBLamport1VNI2, @@ -171,7 +171,7 @@ func TestRelayManager_handleNewServerEndpointRunLoop(t *testing.T) { }, }, { - name: "eq server+vni neq ep higher lamport", + name: "eq-server-vni-neq-ep-higher-lamport", events: []newRelayServerEndpointEvent{ serverAendpointBLamport1VNI2, serverAendpointALamport2VNI2, @@ -181,7 +181,7 @@ func TestRelayManager_handleNewServerEndpointRunLoop(t *testing.T) { }, }, { - name: "eq server+endpoint+vni higher lamport", + name: "eq-server-endpoint-vni-higher-lamport", events: []newRelayServerEndpointEvent{ serverAendpointALamport1VNI1, serverAendpointALamport2VNI1, @@ -191,7 +191,7 @@ func TestRelayManager_handleNewServerEndpointRunLoop(t *testing.T) { }, }, { - name: "eq server+endpoint+vni lower lamport", + name: "eq-server-endpoint-vni-lower-lamport", events: []newRelayServerEndpointEvent{ serverAendpointALamport2VNI1, serverAendpointALamport1VNI1, @@ -201,7 +201,7 @@ func TestRelayManager_handleNewServerEndpointRunLoop(t *testing.T) { }, }, { - name: "eq endpoint+vni+lamport neq server", + name: "eq-endpoint-vni-lamport-neq-server", events: []newRelayServerEndpointEvent{ serverAendpointALamport1VNI1, serverBendpointALamport1VNI1, @@ -212,7 +212,7 @@ func TestRelayManager_handleNewServerEndpointRunLoop(t *testing.T) { }, }, { - name: "trusted last best with matching server", + name: "trusted-last-best-with-matching-server", events: []newRelayServerEndpointEvent{ serverAendpointALamport1VNI1LastBestMatching, }, diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 59c2613451fa5..11900edbfa400 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -64,10 +64,12 @@ import ( const debugPackets = false // If non-zero, these override the values returned from the corresponding -// functions, below. +// functions, below. They are accessed atomically because background +// goroutines in the gVisor TCP stack read them while test cleanup +// goroutines may be restoring them concurrently. var ( - maxInFlightConnectionAttemptsForTest int - maxInFlightConnectionAttemptsPerClientForTest int + maxInFlightConnectionAttemptsForTest atomic.Int32 + maxInFlightConnectionAttemptsPerClientForTest atomic.Int32 ) // maxInFlightConnectionAttempts returns the global number of in-flight @@ -80,8 +82,8 @@ var ( // connection, so we want to ensure that we don't allow an unbounded number of // connections. func maxInFlightConnectionAttempts() int { - if n := maxInFlightConnectionAttemptsForTest; n > 0 { - return n + if n := maxInFlightConnectionAttemptsForTest.Load(); n > 0 { + return int(n) } if version.IsMobile() { @@ -106,8 +108,8 @@ func maxInFlightConnectionAttempts() int { // maxInFlightConnectionAttempts, but applies on a per-client basis // (i.e. keyed by the remote Tailscale IP). func maxInFlightConnectionAttemptsPerClient() int { - if n := maxInFlightConnectionAttemptsPerClientForTest; n > 0 { - return n + if n := maxInFlightConnectionAttemptsPerClientForTest.Load(); n > 0 { + return int(n) } // For now, allow each individual client at most 2/3rds of the global @@ -119,6 +121,22 @@ func maxInFlightConnectionAttemptsPerClient() int { var debugNetstack = envknob.RegisterBool("TS_DEBUG_NETSTACK") +// netstackKeepaliveIdle overrides the netstack default (~2h) TCP keepalive +// idle time for forwarded connections. When a tailnet peer goes away without +// closing its connections (pod deleted, peer removed from netmap, silent +// network partition), the forwardTCP io.Copy goroutines block until keepalive +// fires. Under high-churn forwarding — many short-lived peers, or peers +// holding thousands of proxied connections that drop at once — the 2h default +// lets stuck goroutines accumulate faster than they clear. Value is a Go +// duration, e.g. "60s". See tailscale/tailscale#4522. +var netstackKeepaliveIdle = envknob.RegisterDuration("TS_NETSTACK_KEEPALIVE_IDLE") + +// netstackKeepaliveInterval overrides the netstack default (75s) TCP keepalive +// probe interval for forwarded connections. Independent of +// netstackKeepaliveIdle; setting one without the other leaves the unset knob +// at the netstack default. Value is a Go duration, e.g. "15s". +var netstackKeepaliveInterval = envknob.RegisterDuration("TS_NETSTACK_KEEPALIVE_INTERVAL") + var ( serviceIP = tsaddr.TailscaleServiceIP() serviceIPv6 = tsaddr.TailscaleServiceIPv6() @@ -603,15 +621,25 @@ type LocalBackend = any // Start sets up all the handlers so netstack can start working. Implements // wgengine.FakeImpl. +// +// The provided LocalBackend interface can be either nil, for special case users +// of netstack that don't have a LocalBackend, or a non-nil +// *ipnlocal.LocalBackend. Any other type will cause Start to panic. +// +// Start currently (2026-03-11) never returns a non-nil error, but maybe it did +// in the past and maybe it will in the future. func (ns *Impl) Start(b LocalBackend) error { - if b == nil { - panic("nil LocalBackend interface") - } - lb := b.(*ipnlocal.LocalBackend) - if lb == nil { - panic("nil LocalBackend") + switch b := b.(type) { + case nil: + // No backend, so just continue with ns.lb unset. + case *ipnlocal.LocalBackend: + if b == nil { + panic("nil LocalBackend") + } + ns.lb = b + default: + panic(fmt.Sprintf("unexpected type for LocalBackend: %T", b)) } - ns.lb = lb tcpFwd := tcp.NewForwarder(ns.ipstack, tcpRXBufDefSize, maxInFlightConnectionAttempts(), ns.acceptTCP) udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDPNoICMP) ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, ns.wrapTCPProtocolHandler(tcpFwd.HandlePacket)) @@ -817,20 +845,27 @@ func (ns *Impl) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper, gro *gro. serviceName, isVIPServiceIP := ns.atomicIPVIPServiceMap.Load()[dst] switch { case dst == serviceIP || dst == serviceIPv6: - // We want to intercept some traffic to the "service IP" (e.g. - // 100.100.100.100 for IPv4). However, of traffic to the - // service IP, we only care about UDP 53, and TCP on port 53, - // 80, and 8080. - switch p.IPProto { - case ipproto.TCP: - if port := p.Dst.Port(); port != 53 && port != 80 && port != 8080 && !ns.isLoopbackPort(port) { - return filter.Accept, gro - } - case ipproto.UDP: - if port := p.Dst.Port(); port != 53 && !ns.isLoopbackPort(port) { - return filter.Accept, gro - } - } + // Traffic to the Tailscale service IP (100.100.100.100 / + // fd7a:115c:a1e0::53) is always terminated locally on this + // node; it must never be forwarded out over WireGuard to a + // peer. Netstack's TCP/UDP acceptors handle the ports we + // actually serve (UDP 53 MagicDNS, TCP 53/80/8080 for DNS, + // the web client, and Taildrive, plus any debug loopback + // port). Other ports are rejected cleanly by netstack: UDP + // closes the endpoint in acceptUDP, and TCP is RST'd by + // acceptTCP's hittingServiceIP guard. + // + // Previously we returned filter.Accept for TCP/UDP on any + // other port, which let the packet fall through to the ACL + // filter and ultimately wireguard-go, where no peer owns the + // quad-100 AllowedIP. That produced noisy "open-conn-track: + // timeout opening ...; no associated peer node" log lines + // (e.g. for stray traffic to 100.100.100.100:853 / DoT) and + // leaked quad-100 packets onto the tailnet. + // + // We now unconditionally absorb quad-100 into netstack here, + // regardless of IP protocol or port, so such traffic never + // reaches the conntrack / peer-routing layers. case isVIPServiceIP: // returns all active VIP services in a set, since the IPVIPServiceMap // contains inactive service IPs when node hosts the service, we need to @@ -1205,6 +1240,34 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool { return true } } + // check if there's a registered UDP endpoint for this service VIP + // This allows userspace UDP listeners (e.g., via tsnet.ListenPacket) to + // receive traffic on service VIP addresses. + if p.IPProto == ipproto.UDP { + var netProto tcpip.NetworkProtocolNumber + var id stack.TransportEndpointID + if p.Dst.Addr().Is4() { + netProto = ipv4.ProtocolNumber + id = stack.TransportEndpointID{ + LocalAddress: tcpip.AddrFrom4(p.Dst.Addr().As4()), + LocalPort: p.Dst.Port(), + RemoteAddress: tcpip.AddrFrom4(p.Src.Addr().As4()), + RemotePort: p.Src.Port(), + } + } else { + netProto = ipv6.ProtocolNumber + id = stack.TransportEndpointID{ + LocalAddress: tcpip.AddrFrom16(p.Dst.Addr().As16()), + LocalPort: p.Dst.Port(), + RemoteAddress: tcpip.AddrFrom16(p.Src.Addr().As16()), + RemotePort: p.Src.Port(), + } + } + ep := ns.ipstack.FindTransportEndpoint(netProto, udp.ProtocolNumber, id, nicID) + if ep != nil { + return true + } + } return false } if p.IPVersion == 6 && !isLocal && viaRange.Contains(dstIP) { @@ -1471,6 +1534,7 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { dialIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress) isTailscaleIP := tsaddr.IsTailscaleIP(dialIP) + isLocal := ns.isLocalIP(dialIP) // i.e. not a subnet routed or 4via6 target dstAddrPort := netip.AddrPortFrom(dialIP, reqDetails.LocalPort) @@ -1509,14 +1573,26 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { // Applications might be setting this on a forwarded connection, but from // userspace we can not see those, so the best we can do is to always // perform them with conservative timing. - // TODO(tailscale/tailscale#4522): Netstack defaults match the Linux - // defaults, and results in a little over two hours before the socket would - // be closed due to keepalive. A shorter default might be better, or seeking - // a default from the host IP stack. This also might be a useful - // user-tunable, as in userspace mode this can have broad implications such - // as lingering connections to fork style daemons. On the other side of the - // fence, the long duration timers are low impact values for battery powered - // peers. + // Netstack defaults match the Linux defaults and result in a little over + // two hours before the socket is closed due to keepalive. Operators can + // shorten the timers with TS_NETSTACK_KEEPALIVE_IDLE and + // TS_NETSTACK_KEEPALIVE_INTERVAL (see netstackKeepaliveIdle); the + // defaults are left unchanged because the long timers are low-impact for + // battery-powered peers and this has broad implications in userspace + // mode (lingering connections to fork-style daemons, etc). See + // tailscale/tailscale#4522. + if d := netstackKeepaliveIdle(); d > 0 { + idle := tcpip.KeepaliveIdleOption(d) + if err := ep.SetSockOpt(&idle); err != nil { + ns.logf("netstack: SetSockOpt(KeepaliveIdle=%v) failed: %v", d, err) + } + } + if d := netstackKeepaliveInterval(); d > 0 { + intvl := tcpip.KeepaliveIntervalOption(d) + if err := ep.SetSockOpt(&intvl); err != nil { + ns.logf("netstack: SetSockOpt(KeepaliveInterval=%v) failed: %v", d, err) + } + } ep.SocketOptions().SetKeepAlive(true) // This function is called when we're ready to use the @@ -1585,12 +1661,30 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { } else { dialIP = ipv4Loopback } + case hittingServiceIP: + // TCP to the Tailscale service IP on a port we don't serve + // (anything other than DNS/53, web client/80, Taildrive/8080, + // or the debug loopback port handled above). handleLocalPackets + // absorbs all quad-100 traffic into netstack to prevent it + // from leaking to WireGuard peers as noisy "open-conn-track: + // timeout opening ...; no associated peer node" log lines + // (see the comment there). + // + // Without this explicit guard, execution would fall through + // to the isTailscaleIP case below (quad-100 is in the + // tailscale IP range), rewriting the dial target to + // 127.0.0.1: and forwardTCP'ing the connection onto + // whatever random service happens to be listening on the + // host's loopback at that port. Reject cleanly with a RST + // here instead. + r.Complete(true) // sends a RST + return case isTailscaleIP: dialIP = ipv4Loopback } dialAddr := netip.AddrPortFrom(dialIP, uint16(reqDetails.LocalPort)) - if !ns.forwardTCP(getConnOrReset, clientRemoteIP, &wq, dialAddr) { + if !ns.forwardTCP(getConnOrReset, clientRemoteIP, &wq, dialAddr, isLocal) { r.Complete(true) // sends a RST } } @@ -1602,7 +1696,7 @@ type tcpCloser interface { CloseWrite() error } -func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.TCPConn, clientRemoteIP netip.Addr, wq *waiter.Queue, dialAddr netip.AddrPort) (handled bool) { +func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.TCPConn, clientRemoteIP netip.Addr, wq *waiter.Queue, dialAddr netip.AddrPort, isLocal bool) (handled bool) { dialAddrStr := dialAddr.String() if debugNetstack() { ns.logf("[v2] netstack: forwarding incoming connection to %s", dialAddrStr) @@ -1649,11 +1743,13 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet. backendLocalAddr := backend.LocalAddr().(*net.TCPAddr) backendLocalIPPort := netaddr.Unmap(backendLocalAddr.AddrPort()) - if err := ns.pm.RegisterIPPortIdentity("tcp", backendLocalIPPort, clientRemoteIP); err != nil { - ns.logf("netstack: could not register TCP mapping %s: %v", backendLocalIPPort, err) - return + if isLocal { + if err := ns.pm.RegisterIPPortIdentity("tcp", backendLocalIPPort, clientRemoteIP); err != nil { + ns.logf("netstack: could not register TCP mapping %s: %v", backendLocalIPPort, err) + return + } + defer ns.pm.UnregisterIPPortIdentity("tcp", backendLocalIPPort) } - defer ns.pm.UnregisterIPPortIdentity("tcp", backendLocalIPPort) // If we get here, either the getClient call below will succeed and // return something we can Close, or it will fail and will properly diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go index da262fc13acbd..4f920c8e0271f 100644 --- a/wgengine/netstack/netstack_test.go +++ b/wgengine/netstack/netstack_test.go @@ -33,6 +33,7 @@ import ( "tailscale.com/types/ipproto" "tailscale.com/types/logid" "tailscale.com/types/netmap" + "tailscale.com/util/clientmetric" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" ) @@ -453,6 +454,194 @@ func TestShouldProcessInbound(t *testing.T) { }, want: false, }, + { + name: "udp-on-service-vip-with-listener-ipv4", + pkt: &packet.Parsed{ + IPVersion: 4, + IPProto: ipproto.UDP, + Src: netip.MustParseAddrPort("100.101.102.103:1234"), + Dst: netip.MustParseAddrPort("100.100.100.100:8080"), + }, + beforeStart: func(i *Impl) { + i.ProcessLocalIPs = false + i.ProcessSubnets = false + }, + afterStart: func(i *Impl) { + IPServiceMap := netmap.IPServiceMappings{ + serviceIP: "svc:test-service", + } + i.lb.SetIPServiceMappingsForTest(IPServiceMap) + + i.atomicIsVIPServiceIPFunc.Store(func(addr netip.Addr) bool { + return addr == serviceIP + }) + + // Register the service VIP address on the NIC so gVisor can route to it + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddrFrom4(serviceIP.As4()).WithPrefix(), + } + + if err := i.ipstack.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress: %v", err) + } + + // Create a UDP listener on the service VIP + pc, err := gonet.DialUDP(i.ipstack, &tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFrom4(serviceIP.As4()), + Port: 8080, + }, nil, header.IPv4ProtocolNumber) + if err != nil { + t.Fatalf("DialUDP: %v", err) + } + t.Cleanup(func() { pc.Close() }) + + i.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress) + }, + want: true, + }, + { + name: "udp-on-service-vip-no-listener-ipv4", + pkt: &packet.Parsed{ + IPVersion: 4, + IPProto: ipproto.UDP, + Src: netip.MustParseAddrPort("100.101.102.103:1234"), + Dst: netip.MustParseAddrPort("100.100.100.100:9999"), + }, + beforeStart: func(i *Impl) { + i.ProcessLocalIPs = false + i.ProcessSubnets = false + }, + afterStart: func(i *Impl) { + IPServiceMap := netmap.IPServiceMappings{ + serviceIP: "svc:test-service", + } + i.lb.SetIPServiceMappingsForTest(IPServiceMap) + + i.atomicIsVIPServiceIPFunc.Store(func(addr netip.Addr) bool { + return addr == serviceIP + }) + + i.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress) + }, + want: false, + }, + { + name: "udp-on-service-vip-with-listener-ipv6", + pkt: &packet.Parsed{ + IPVersion: 6, + IPProto: ipproto.UDP, + Src: netip.MustParseAddrPort("[fd7a:115c:a1e0::1]:1234"), + Dst: netip.MustParseAddrPort("[fd7a:115c:a1e0::53]:8080"), + }, + beforeStart: func(i *Impl) { + i.ProcessLocalIPs = false + i.ProcessSubnets = false + }, + afterStart: func(i *Impl) { + IPServiceMap := netmap.IPServiceMappings{ + serviceIPv6: "svc:test-service", + } + i.lb.SetIPServiceMappingsForTest(IPServiceMap) + + i.atomicIsVIPServiceIPFunc.Store(func(addr netip.Addr) bool { + return addr == serviceIPv6 + }) + + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: tcpip.AddrFrom16(serviceIPv6.As16()).WithPrefix(), + } + if err := i.ipstack.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress: %v", err) + } + + pc, err := gonet.DialUDP(i.ipstack, &tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFrom16(serviceIPv6.As16()), + Port: 8080, + }, nil, header.IPv6ProtocolNumber) + if err != nil { + t.Fatalf("DialUDP: %v", err) + } + t.Cleanup(func() { pc.Close() }) + + i.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress) + }, + want: true, + }, + { + name: "udp-on-service-vip-no-listener-ipv6", + pkt: &packet.Parsed{ + IPVersion: 6, + IPProto: ipproto.UDP, + Src: netip.MustParseAddrPort("[fd7a:115c:a1e0::1]:1234"), + Dst: netip.AddrPortFrom(serviceIPv6, 9999), + }, + beforeStart: func(i *Impl) { + i.ProcessLocalIPs = false + i.ProcessSubnets = false + }, + afterStart: func(i *Impl) { + IPServiceMap := netmap.IPServiceMappings{ + serviceIPv6: "svc:test-service", + } + i.lb.SetIPServiceMappingsForTest(IPServiceMap) + + i.atomicIsVIPServiceIPFunc.Store(func(addr netip.Addr) bool { + return addr == serviceIPv6 + }) + + i.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress) + }, + want: false, + }, + { + name: "tcp-on-service-vip-with-udp-listener", + pkt: &packet.Parsed{ + IPVersion: 4, + IPProto: ipproto.TCP, + Src: netip.MustParseAddrPort("100.101.102.103:1234"), + Dst: netip.MustParseAddrPort("100.100.100.100:8080"), // serviceIP + TCPFlags: packet.TCPSyn, + }, + beforeStart: func(i *Impl) { + i.ProcessLocalIPs = false + i.ProcessSubnets = false + }, + afterStart: func(i *Impl) { + IPServiceMap := netmap.IPServiceMappings{ + serviceIP: "svc:test-service", + } + i.lb.SetIPServiceMappingsForTest(IPServiceMap) + + i.atomicIsVIPServiceIPFunc.Store(func(addr netip.Addr) bool { + return addr == serviceIP + }) + + protocolAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddrFrom4(serviceIP.As4()).WithPrefix(), + } + if err := i.ipstack.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil { + t.Fatalf("AddProtocolAddress: %v", err) + } + + pc, err := gonet.DialUDP(i.ipstack, &tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFrom4(serviceIP.As4()), + Port: 8080, + }, nil, header.IPv4ProtocolNumber) + if err != nil { + t.Fatalf("DialUDP: %v", err) + } + t.Cleanup(func() { pc.Close() }) + + i.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress) + }, + want: false, + }, // TODO(andrew): test PeerAPI // TODO(andrew): test TCP packets without the SYN flag set @@ -623,11 +812,15 @@ func TestTCPForwardLimits(t *testing.T) { // TestTCPForwardLimits_PerClient verifies that the per-client limit for TCP // forwarding works. func TestTCPForwardLimits_PerClient(t *testing.T) { + clientmetric.ResetForTest(t) + tstest.AssertNotParallel(t) // calls envknob.Setenv envknob.Setenv("TS_DEBUG_NETSTACK", "true") // Set our test override limits during this test. - tstest.Replace(t, &maxInFlightConnectionAttemptsForTest, 2) - tstest.Replace(t, &maxInFlightConnectionAttemptsPerClientForTest, 1) + maxInFlightConnectionAttemptsForTest.Store(2) + t.Cleanup(func() { maxInFlightConnectionAttemptsForTest.Store(0) }) + maxInFlightConnectionAttemptsPerClientForTest.Store(1) + t.Cleanup(func() { maxInFlightConnectionAttemptsPerClientForTest.Store(0) }) impl := makeNetstack(t, func(impl *Impl) { impl.ProcessSubnets = true @@ -760,6 +953,7 @@ func TestHandleLocalPackets(t *testing.T) { impl.lb.SetIPServiceMappingsForTest(IPServiceMap) t.Run("ShouldHandleServiceIP", func(t *testing.T) { + t.Parallel() pkt := &packet.Parsed{ IPVersion: 4, IPProto: ipproto.TCP, @@ -772,7 +966,94 @@ func TestHandleLocalPackets(t *testing.T) { t.Errorf("got filter outcome %v, want filter.DropSilently", resp) } }) + // Any port on the quad-100 service IP must be absorbed locally by + // netstack and never leak out to the WireGuard / peer-routing + // layers. Historically we only intercepted specific ports (UDP 53 + // and TCP 53/80/8080), causing stray traffic to other ports such + // as 100.100.100.100:853 (DoT) to time out in wireguard-go and + // produce "open-conn-track: timeout opening ...; no associated + // peer node" log spam. See the handleLocalPackets comment. + quad100LeakCases := []struct { + name string + proto ipproto.Proto + dst string + }{ + {"TCP-853-DoT-v4", ipproto.TCP, "100.100.100.100:853"}, + {"TCP-443-DoH-v4", ipproto.TCP, "100.100.100.100:443"}, + {"TCP-9000-stray-v4", ipproto.TCP, "100.100.100.100:9000"}, + {"UDP-853-DoQ-v4", ipproto.UDP, "100.100.100.100:853"}, + {"UDP-443-v4", ipproto.UDP, "100.100.100.100:443"}, + {"TCP-853-DoT-v6", ipproto.TCP, "[fd7a:115c:a1e0::53]:853"}, + {"UDP-443-v6", ipproto.UDP, "[fd7a:115c:a1e0::53]:443"}, + } + for _, tc := range quad100LeakCases { + t.Run("ShouldNotLeakQuad100_"+tc.name, func(t *testing.T) { + t.Parallel() + dst := netip.MustParseAddrPort(tc.dst) + ipVersion := uint8(4) + if dst.Addr().Is6() { + ipVersion = 6 + } + src := "127.0.0.1:9999" + if ipVersion == 6 { + src = "[::1]:9999" + } + pkt := &packet.Parsed{ + IPVersion: ipVersion, + IPProto: tc.proto, + Src: netip.MustParseAddrPort(src), + Dst: dst, + } + if tc.proto == ipproto.TCP { + pkt.TCPFlags = packet.TCPSyn + } + resp, _ := impl.handleLocalPackets(pkt, impl.tundev, nil) + if resp != filter.DropSilently { + t.Errorf("quad-100 %s packet leaked: got filter outcome %v, want filter.DropSilently", tc.name, resp) + } + }) + } + // Exhaustive sweep of all ports for both transport protocols and + // both IP versions, confirming no port leaks. The quad-100 branch + // of handleLocalPackets is port-independent by construction; this + // test serves as a regression guard against accidental port-based + // exemptions slipping back in. + t.Run("ShouldNotLeakQuad100_AllPorts", func(t *testing.T) { + t.Parallel() + protos := []ipproto.Proto{ipproto.TCP, ipproto.UDP} + dsts := []netip.Addr{ + netip.MustParseAddr("100.100.100.100"), + netip.MustParseAddr("fd7a:115c:a1e0::53"), + } + for _, proto := range protos { + for _, dstAddr := range dsts { + ipVersion := uint8(4) + srcStr := "127.0.0.1:9999" + if dstAddr.Is6() { + ipVersion = 6 + srcStr = "[::1]:9999" + } + src := netip.MustParseAddrPort(srcStr) + for port := 1; port <= 65535; port++ { + pkt := &packet.Parsed{ + IPVersion: ipVersion, + IPProto: proto, + Src: src, + Dst: netip.AddrPortFrom(dstAddr, uint16(port)), + } + if proto == ipproto.TCP { + pkt.TCPFlags = packet.TCPSyn + } + resp, _ := impl.handleLocalPackets(pkt, impl.tundev, nil) + if resp != filter.DropSilently { + t.Fatalf("port=%d proto=%v dst=%v: got %v, want filter.DropSilently", port, proto, dstAddr, resp) + } + } + } + } + }) t.Run("ShouldHandle4via6", func(t *testing.T) { + t.Parallel() pkt := &packet.Parsed{ IPVersion: 6, IPProto: ipproto.TCP, @@ -795,6 +1076,7 @@ func TestHandleLocalPackets(t *testing.T) { } }) t.Run("ShouldHandleLocalTailscaleServices", func(t *testing.T) { + t.Parallel() pkt := &packet.Parsed{ IPVersion: 4, IPProto: ipproto.TCP, @@ -808,6 +1090,7 @@ func TestHandleLocalPackets(t *testing.T) { } }) t.Run("OtherNonHandled", func(t *testing.T) { + t.Parallel() pkt := &packet.Parsed{ IPVersion: 6, IPProto: ipproto.TCP, @@ -830,6 +1113,100 @@ func TestHandleLocalPackets(t *testing.T) { }) } +// TestQuad100UnservedTCPPortDoesNotForward verifies that a TCP SYN to the +// Tailscale service IP (100.100.100.100) on a port we don't serve is +// absorbed by netstack and rejected cleanly, without triggering the +// outbound forwardTCP dialer. +// +// handleLocalPackets now absorbs all quad-100 traffic regardless of +// port to prevent it leaking to WireGuard peers (which produced noisy +// "open-conn-track: timeout opening ...; no associated peer node" log +// lines). That leaves acceptTCP responsible for rejecting connections +// to ports we don't handle; without an explicit guard, execution would +// fall through to the isTailscaleIP case (quad-100 is in the tailscale +// range), rewriting the dial target to 127.0.0.1: and forwarding +// the connection to whatever random service happened to be listening +// on the host's loopback at that port. +// +// This test asserts that the forward dialer is NOT invoked for quad-100 +// SYNs on unserved ports; the guard in acceptTCP must RST instead. +func TestQuad100UnservedTCPPortDoesNotForward(t *testing.T) { + impl := makeNetstack(t, func(impl *Impl) { + impl.ProcessSubnets = false + impl.ProcessLocalIPs = false + impl.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress) + }) + + dialFn, gotConn := makeHangDialer(t) + impl.forwardDialFunc = dialFn + + // Use a client IP in the CGNAT range so shouldProcessInbound-adjacent + // code paths treat this as plausibly-peer-sourced traffic, matching + // what a real stray quad-100 probe from the host OS would look like. + client := netip.MustParseAddr("100.101.102.103") + quad100 := tsaddr.TailscaleServiceIP() + + // 853 is DoT, the specific case called out in the original bug + // report ("conntrack error no peer found for 100.100.100.100:853"). + // Before the fix, port 853 (and any non-{53,80,8080} port) leaked + // out to WireGuard; after the fix it is absorbed here and must NOT + // trigger forwardTCP. + pkt := tcp4syn(t, client, quad100, 1234, 853) + var parsed packet.Parsed + parsed.Decode(pkt) + + resp, _ := impl.handleLocalPackets(&parsed, impl.tundev, nil) + if resp != filter.DropSilently { + t.Fatalf("handleLocalPackets for quad-100:853: got %v, want filter.DropSilently", resp) + } + + // acceptTCP runs asynchronously in the gVisor TCP dispatcher after + // handleLocalPackets injects the packet into netstack. Use the + // in-flight connection counter as a deterministic synchronization + // point: wrapTCPProtocolHandler increments connsInFlightByClient + // when the dispatcher hands the connection off to acceptTCP, and + // acceptTCP's deferred decrementInFlightTCPForward decrements it + // on return. + // + // On the green path (RST guard fires), acceptTCP returns promptly + // and the counter reaches 0. On the red path (fall-through to + // forwardTCP), acceptTCP blocks inside the forwardDialFunc call — + // makeHangDialer signals gotConn on entry (buffered, non-blocking) + // and then blocks forever — so the counter never reaches 0 but + // gotConn fires synchronously from the dispatcher goroutine. A + // select on both races those outcomes without real-time padding. + // + // testing/synctest is not usable here: gVisor's sleep package calls + // the runtime's gopark directly rather than via the standard + // library, so synctest.Wait() cannot observe those goroutines + // becoming durably blocked and hangs indefinitely. + inFlightZero := make(chan struct{}) + go func() { + for { + impl.mu.Lock() + n := impl.connsInFlightByClient[client] + impl.mu.Unlock() + if n == 0 { + close(inFlightZero) + return + } + time.Sleep(time.Millisecond) + } + }() + + select { + case <-gotConn: + t.Fatalf("forwardDialFunc was called for quad-100:853; acceptTCP fell through to forwardTCP instead of sending RST. This means stray traffic to quad-100 on unserved ports is being redirected to the host's loopback at the same port.") + case <-inFlightZero: + // acceptTCP returned cleanly; the RST guard fired. + case <-time.After(5 * time.Second): + // Safety net so a regression in the in-flight counter plumbing + // doesn't hang the whole test run; both outcomes above should + // fire within milliseconds in practice. + t.Fatal("timed out waiting for acceptTCP to dispatch quad-100:853 SYN") + } +} + func TestShouldSendToHost(t *testing.T) { var ( selfIP4 = netip.MustParseAddr("100.64.1.2") diff --git a/wgengine/netstack/netstack_userping_apple.go b/wgengine/netstack/netstack_userping_apple.go index a82b81e99e827..cb6926f0a9a4b 100644 --- a/wgengine/netstack/netstack_userping_apple.go +++ b/wgengine/netstack/netstack_userping_apple.go @@ -6,33 +6,30 @@ package netstack import ( + "context" + "net" "net/netip" "time" - probing "github.com/prometheus-community/pro-bing" + "tailscale.com/net/ping" ) // sendOutboundUserPing sends a non-privileged ICMP (or ICMPv6) ping to dstIP with the given timeout. func (ns *Impl) sendOutboundUserPing(dstIP netip.Addr, timeout time.Duration) error { - p, err := probing.NewPinger(dstIP.String()) - if err != nil { - ns.logf("sendICMPPingToIP failed to create pinger: %v", err) - return err - } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() - p.Timeout = timeout - p.Count = 1 - p.SetPrivileged(false) + p := ping.New(ctx, ns.logf, nil) + p.Unprivileged = true + defer p.Close() - p.OnSend = func(pkt *probing.Packet) { - ns.logf("sendICMPPingToIP: forwarding ping to %s:", p.Addr()) - } - p.OnRecv = func(pkt *probing.Packet) { - ns.logf("sendICMPPingToIP: %d bytes pong from %s: icmp_seq=%d time=%v", pkt.Nbytes, pkt.IPAddr, pkt.Seq, pkt.Rtt) - } - p.OnFinish = func(stats *probing.Statistics) { - ns.logf("sendICMPPingToIP: done, %d replies received", stats.PacketsRecv) + dst := &net.IPAddr{IP: dstIP.AsSlice(), Zone: dstIP.Zone()} + ns.logf("sendOutboundUserPing: forwarding ping to %s", dstIP) + d, err := p.Send(ctx, dst, []byte("tailscale-userping")) + if err != nil { + ns.logf("sendOutboundUserPing: ping to %s failed: %v", dstIP, err) + return err } - - return p.Run() + ns.logf("sendOutboundUserPing: pong from %s in %v", dstIP, d) + return nil } diff --git a/wgengine/pendopen.go b/wgengine/pendopen.go index 77cb4a7b9b451..e816506def871 100644 --- a/wgengine/pendopen.go +++ b/wgengine/pendopen.go @@ -101,8 +101,8 @@ var ( appleIPRange = netip.MustParsePrefix("17.0.0.0/8") canonicalIPs = sync.OnceValue(func() (checkIPFunc func(netip.Addr) bool) { // https://bgp.he.net/AS41231#_prefixes - t := &bart.Table[bool]{} - for _, s := range strings.Fields(` + t := &bart.Lite{} + for s := range strings.FieldsSeq(` 91.189.89.0/24 91.189.91.0/24 91.189.92.0/24 @@ -115,12 +115,9 @@ var ( 185.125.188.0/23 185.125.190.0/24 194.169.254.0/24`) { - t.Insert(netip.MustParsePrefix(s), true) - } - return func(ip netip.Addr) bool { - v, _ := t.Lookup(ip) - return v + t.Insert(netip.MustParsePrefix(s)) } + return t.Contains }) ) diff --git a/wgengine/router/osrouter/router_linux.go b/wgengine/router/osrouter/router_linux.go index 3c261c9120785..73f65cdf149c1 100644 --- a/wgengine/router/osrouter/router_linux.go +++ b/wgengine/router/osrouter/router_linux.go @@ -89,6 +89,7 @@ type linuxRouter struct { connmarkEnabled bool // whether connmark rules are currently enabled netfilterMode preftype.NetfilterMode netfilterKind string + cgnatMode linuxfw.CGNATMode magicsockPortV4 uint16 magicsockPortV6 uint16 } @@ -489,7 +490,9 @@ func (r *linuxRouter) Set(cfg *router.Config) error { // Connmark rules for rp_filter compatibility. // Always enabled when netfilter is ON to handle all rp_filter=1 scenarios // (normal operation, exit nodes, subnet routers, and clients using exit nodes). - netfilterOn := cfg.NetfilterMode == netfilterOn + // Gate on r.netfilterMode (actual state) rather than cfg.NetfilterMode + // (desired state) so we don't call into the runner when chain setup failed. + netfilterOn := r.netfilterMode == netfilterOn switch { case netfilterOn == r.connmarkEnabled: // state already correct, nothing to do. @@ -502,6 +505,14 @@ func (r *linuxRouter) Set(cfg *router.Config) error { // Only update state on success to keep it in sync with actual rules r.connmarkEnabled = true } + // Enable src_valid_mark so the kernel uses the packet's fwmark + // during the rp_filter reverse-path check. Without this, the + // connmark restore in mangle/PREROUTING is ineffective — rp_filter + // does its routing lookup with fwmark=0, ignoring the restored + // bypass mark, and drops reply packets as martians. + if err := writeSysctl("net.ipv4.conf.all.src_valid_mark", "1"); err != nil { + r.logf("warning: failed to enable src_valid_mark: %v", err) + } default: r.logf("disabling connmark-based rp_filter workaround") if err := r.nfr.DelConnmarkSaveRule(); err != nil { @@ -521,9 +532,50 @@ func (r *linuxRouter) Set(cfg *router.Config) error { r.enableIPForwarding() } + // Remove the rule to drop off-tailnet CGNAT traffic, if needed. + if netfilterOn || r.netfilterMode == netfilterNoDivert { + var cgnatMode linuxfw.CGNATMode + if cfg.RemoveCGNATDropRule { + cgnatMode = linuxfw.CGNATModeReturn + } else { + cgnatMode = linuxfw.CGNATModeDrop + } + err := r.setCGNATDropModeLocked(cgnatMode) + if err != nil { + errs = append(errs, fmt.Errorf("set cgnat mode: %w", err)) + } + } + return errors.Join(errs...) } +// setCGNATDropModeLocked clears old rules and add new rules for the desired +// behavior for incoming non-Tailscale CGNAT packets. +// [linuxRouter.mu] must be held. +func (r *linuxRouter) setCGNATDropModeLocked(want linuxfw.CGNATMode) error { + if want == r.cgnatMode { + return nil + } + // r.cgnatMode is empty at initial startup, before this function has been + // called for the first time. In that case, we can skip deleting old + // rules, because there aren't any. + if r.cgnatMode != "" { + err := r.nfr.DelExternalCGNATRules(r.cgnatMode, r.tunname) + if err != nil { + return fmt.Errorf("clear old cgnat rules: %w", err) + } + } + err := r.nfr.AddExternalCGNATRules(want, r.tunname) + if err != nil { + // We currently have no rules set, so change the state to reflect that + // so we might try again on a future Router update. + r.cgnatMode = "" + return fmt.Errorf("add new cgnat rules: %w", err) + } + r.cgnatMode = want + return nil +} + var dockerStatefulFilteringWarnable = health.Register(&health.Warnable{ Code: "docker-stateful-filtering", Title: "Docker with stateful filtering", @@ -772,6 +824,20 @@ func (r *linuxRouter) setNetfilterModeLocked(mode preftype.NetfilterMode) error } } + // Re-add the CGNAT rules if we had any set. + // This does not call [linuxRouter.setCGNATDropModeLocked] because that + // function assumes that [linuxRouter.cgnatMode] accurately represents the + // current state in the firewall. This would not be true when we hit this + // code path, and is what we're fixing up here. + if r.cgnatMode != "" { + if err := r.nfr.AddExternalCGNATRules(r.cgnatMode, r.tunname); err != nil { + // We currently have no rules set, so change the state to reflect that + // so we might try again on a future Router update. + r.cgnatMode = "" + return fmt.Errorf("add cgnat rules: %w", err) + } + } + return nil } diff --git a/wgengine/router/osrouter/router_linux_test.go b/wgengine/router/osrouter/router_linux_test.go index bae997e331d55..340ebb1486a0f 100644 --- a/wgengine/router/osrouter/router_linux_test.go +++ b/wgengine/router/osrouter/router_linux_test.go @@ -54,13 +54,13 @@ ip rule add -6 pref 5270 table 52 want string }{ { - name: "no config", + name: "no-config", in: nil, want: ` up` + basic, }, { - name: "local addr only", + name: "local-addr-only", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.103/10"), NetfilterMode: netfilterOff, @@ -71,7 +71,7 @@ ip addr add 100.101.102.103/10 dev tailscale0` + basic, }, { - name: "addr and routes", + name: "addr-and-routes", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.103/10"), Routes: mustCIDRs("100.100.100.100/32", "192.168.16.0/24"), @@ -85,7 +85,7 @@ ip route add 192.168.16.0/24 dev tailscale0 table 52` + basic, }, { - name: "addr and routes and subnet routes", + name: "addr-routes-and-subnet-routes", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.103/10"), Routes: mustCIDRs("100.100.100.100/32", "192.168.16.0/24"), @@ -100,7 +100,7 @@ ip route add 192.168.16.0/24 dev tailscale0 table 52` + basic, }, { - name: "addr and routes and subnet routes with netfilter", + name: "addr-routes-subnet-routes-with-netfilter", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.104/10"), Routes: mustCIDRs("100.100.100.100/32", "10.0.0.0/8"), @@ -141,7 +141,7 @@ v6/nat/ts-postrouting -m mark --mark 0x40000/0xff0000 -j MASQUERADE `, }, { - name: "addr and routes and subnet routes with netfilter but no stateful filtering", + name: "addr-routes-subnet-routes-netfilter-no-stateful", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.104/10"), Routes: mustCIDRs("100.100.100.100/32", "10.0.0.0/8"), @@ -180,7 +180,7 @@ v6/nat/ts-postrouting -m mark --mark 0x40000/0xff0000 -j MASQUERADE `, }, { - name: "addr and routes with netfilter", + name: "addr-and-routes-with-netfilter", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.104/10"), Routes: mustCIDRs("100.100.100.100/32", "10.0.0.0/8"), @@ -215,7 +215,7 @@ v6/nat/POSTROUTING -j ts-postrouting }, { - name: "addr and routes and subnet routes with netfilter but no SNAT", + name: "addr-routes-subnet-routes-netfilter-no-SNAT", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.104/10"), Routes: mustCIDRs("100.100.100.100/32", "10.0.0.0/8"), @@ -251,7 +251,7 @@ v6/nat/POSTROUTING -j ts-postrouting `, }, { - name: "addr and routes with netfilter", + name: "addr-and-routes-with-netfilter-2", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.104/10"), Routes: mustCIDRs("100.100.100.100/32", "10.0.0.0/8"), @@ -286,7 +286,7 @@ v6/nat/POSTROUTING -j ts-postrouting }, { - name: "addr and routes with half netfilter", + name: "addr-and-routes-with-half-netfilter", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.104/10"), Routes: mustCIDRs("100.100.100.100/32", "10.0.0.0/8"), @@ -310,7 +310,7 @@ v6/filter/ts-forward -o tailscale0 -j ACCEPT `, }, { - name: "addr and routes with netfilter2", + name: "addr-and-routes-with-netfilter2", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.104/10"), Routes: mustCIDRs("100.100.100.100/32", "10.0.0.0/8"), @@ -344,7 +344,7 @@ v6/nat/POSTROUTING -j ts-postrouting `, }, { - name: "addr, routes, and local routes with netfilter", + name: "addr-routes-local-routes-with-netfilter", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.104/10"), Routes: mustCIDRs("100.100.100.100/32", "0.0.0.0/0"), @@ -380,7 +380,7 @@ v6/nat/POSTROUTING -j ts-postrouting `, }, { - name: "addr, routes, and local routes with no netfilter", + name: "addr-routes-local-routes-no-netfilter", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.104/10"), Routes: mustCIDRs("100.100.100.100/32", "0.0.0.0/0"), @@ -396,7 +396,7 @@ ip route add throw 10.0.0.0/8 table 52 ip route add throw 192.168.0.0/24 table 52` + basic, }, { - name: "subnet routes with connmark for rp_filter", + name: "subnet-routes-connmark-for-rp_filter", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.104/10"), Routes: mustCIDRs("100.100.100.100/32"), @@ -433,7 +433,7 @@ v6/nat/ts-postrouting -m mark --mark 0x40000/0xff0000 -j MASQUERADE `, }, { - name: "subnet routes (connmark always enabled)", + name: "subnet-routes-connmark-always-enabled", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.104/10"), Routes: mustCIDRs("100.100.100.100/32"), @@ -470,7 +470,7 @@ v6/nat/ts-postrouting -m mark --mark 0x40000/0xff0000 -j MASQUERADE `, }, { - name: "connmark with stateful filtering", + name: "connmark-with-stateful-filtering", in: &Config{ LocalAddrs: mustCIDRs("100.101.102.104/10"), Routes: mustCIDRs("100.100.100.100/32"), @@ -562,6 +562,10 @@ type fakeIPTablesRunner struct { ipt4 map[string][]string ipt6 map[string][]string // we always assume ipv6 and ipv6 nat are enabled when testing + + addChainsErr error // if non-nil, AddChains returns it instead of setting up chains + addConnmarkSaveCalls int + addExternalCGNATCalls int } func newIPTablesRunner(t *testing.T) linuxfw.NetfilterRunner { @@ -717,11 +721,11 @@ func (n *fakeIPTablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst n return errors.New("not implemented") } +type iptRule struct{ chain, rule string } + func (n *fakeIPTablesRunner) addBase4(tunname string) error { curIPT := n.ipt4 - newRules := []struct{ chain, rule string }{ - {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j RETURN", tunname, tsaddr.ChromeOSVMRange().String())}, - {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())}, + newRules := []iptRule{ {"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask)}, {"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask)}, {"filter/ts-forward", fmt.Sprintf("-o %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())}, @@ -737,7 +741,7 @@ func (n *fakeIPTablesRunner) addBase4(tunname string) error { func (n *fakeIPTablesRunner) addBase6(tunname string) error { curIPT := n.ipt6 - newRules := []struct{ chain, rule string }{ + newRules := []iptRule{ {"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask)}, {"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask)}, {"filter/ts-forward", fmt.Sprintf("-o %s -j ACCEPT", tunname)}, @@ -762,7 +766,7 @@ func (n *fakeIPTablesRunner) DelLoopbackRule(addr netip.Addr) error { } func (n *fakeIPTablesRunner) AddHooks() error { - newRules := []struct{ chain, rule string }{ + newRules := []iptRule{ {"filter/INPUT", "-j ts-input"}, {"filter/FORWARD", "-j ts-forward"}, {"nat/POSTROUTING", "-j ts-postrouting"}, @@ -778,7 +782,7 @@ func (n *fakeIPTablesRunner) AddHooks() error { } func (n *fakeIPTablesRunner) DelHooks(logf logger.Logf) error { - delRules := []struct{ chain, rule string }{ + delRules := []iptRule{ {"filter/INPUT", "-j ts-input"}, {"filter/FORWARD", "-j ts-forward"}, {"nat/POSTROUTING", "-j ts-postrouting"}, @@ -794,6 +798,9 @@ func (n *fakeIPTablesRunner) DelHooks(logf logger.Logf) error { } func (n *fakeIPTablesRunner) AddChains() error { + if n.addChainsErr != nil { + return n.addChainsErr + } for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { for _, chain := range []string{"filter/ts-input", "filter/ts-forward", "nat/ts-postrouting"} { ipt[chain] = nil @@ -922,6 +929,7 @@ func (n *fakeIPTablesRunner) DelMagicsockPortRule(port uint16, network string) e } func (n *fakeIPTablesRunner) AddConnmarkSaveRule() error { + n.addConnmarkSaveCalls++ // PREROUTING rule: restore mark from conntrack prerouteRule := "-m conntrack --ctstate ESTABLISHED,RELATED -j CONNMARK --restore-mark --nfmask 0xff0000 --ctmask 0xff0000" for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { @@ -953,6 +961,49 @@ func (n *fakeIPTablesRunner) DelConnmarkSaveRule() error { return nil } +func buildExternalCGNATRules(mode linuxfw.CGNATMode, tunname string) ([]iptRule, error) { + switch mode { + case linuxfw.CGNATModeDrop: + return []iptRule{ + {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j RETURN", tunname, tsaddr.ChromeOSVMRange().String())}, + {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())}, + }, nil + case linuxfw.CGNATModeReturn: + return []iptRule{ + {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j RETURN", tunname, tsaddr.CGNATRange().String())}, + }, nil + default: + return nil, fmt.Errorf("unsupported mode %q", mode) + } +} + +func (n *fakeIPTablesRunner) AddExternalCGNATRules(mode linuxfw.CGNATMode, tunname string) error { + n.addExternalCGNATCalls++ + rules, err := buildExternalCGNATRules(mode, tunname) + if err != nil { + return err + } + for _, rule := range rules { + if err := appendRule(n, n.ipt4, rule.chain, rule.rule); err != nil { + return fmt.Errorf("add rule %q to chain %q: %w", rule.rule, rule.chain, err) + } + } + return nil +} + +func (n *fakeIPTablesRunner) DelExternalCGNATRules(mode linuxfw.CGNATMode, tunname string) error { + rules, err := buildExternalCGNATRules(mode, tunname) + if err != nil { + return err + } + for _, rule := range rules { + if err := deleteRule(n, n.ipt4, rule.chain, rule.rule); err != nil { + return fmt.Errorf("del rule %q to chain %q: %w", rule.rule, rule.chain, err) + } + } + return nil +} + func (n *fakeIPTablesRunner) HasIPV6() bool { return true } func (n *fakeIPTablesRunner) HasIPV6NAT() bool { return true } func (n *fakeIPTablesRunner) HasIPV6Filter() bool { return true } @@ -1073,11 +1124,9 @@ func (o *fakeOS) run(args ...string) error { switch args[2] { case "add": - for _, el := range *ls { - if el == rest { - o.t.Errorf("can't add %q, already present", rest) - return errors.New("already exists") - } + if slices.Contains(*ls, rest) { + o.t.Errorf("can't add %q, already present", rest) + return errors.New("already exists") } *ls = append(*ls, rest) sort.Strings(*ls) @@ -1159,9 +1208,7 @@ func (lt *linuxTest) Close() error { } func newLinuxRootTest(t *testing.T) (*linuxTest, *eventbus.Bus) { - if os.Getuid() != 0 { - t.Skip("test requires root") - } + tstest.RequireRoot(t) lt := new(linuxTest) lt.tun = createTestTUN(t) @@ -1213,7 +1260,9 @@ func TestRuleDeletedEvent(t *testing.T) { } func TestDelRouteIdempotent(t *testing.T) { + fake := NewFakeOS(t) lt, _ := newLinuxRootTest(t) + lt.r.nfr = fake.nfr defer lt.Close() for _, s := range []string{ @@ -1239,7 +1288,9 @@ func TestDelRouteIdempotent(t *testing.T) { } func TestAddRemoveRules(t *testing.T) { + fake := NewFakeOS(t) lt, _ := newLinuxRootTest(t) + lt.r.nfr = fake.nfr defer lt.Close() r := lt.r @@ -1508,3 +1559,53 @@ func TestUpdateMagicsockPortChange(t *testing.T) { oldPortRule, nfr.ipt4["filter/ts-input"]) } } + +// TestSetSkipsNetfilterAddonsWhenSetupFails verifies that Set does not invoke +// rule-management methods that depend on the ts-* chains existing when chain +// setup failed. +func TestSetSkipsNetfilterAddonsWhenSetupFails(t *testing.T) { + nfr := newIPTablesRunner(t).(*fakeIPTablesRunner) + nfr.addChainsErr = errors.New("kernel lacks netfilter support") + + bus := eventbus.New() + defer bus.Close() + mon, err := netmon.New(bus, logger.Discard) + if err != nil { + t.Fatal(err) + } + mon.Start() + defer mon.Close() + + fake := NewFakeOS(t) + ht := health.NewTracker(bus) + r, err := newUserspaceRouterAdvanced(logger.Discard, "tailscale0", mon, fake, ht, bus) + if err != nil { + t.Fatalf("newUserspaceRouterAdvanced: %v", err) + } + lr := r.(*linuxRouter) + lr.nfr = nfr + if err := lr.Up(); err != nil { + t.Fatalf("Up: %v", err) + } + defer lr.Close() + + cfg := &Config{ + LocalAddrs: mustCIDRs("100.101.102.103/10"), + NetfilterMode: netfilterOn, + } + // Set must return an error (chain setup failed) but must not panic. + if err := lr.Set(cfg); err == nil { + t.Fatal("Set returned nil; want error because AddChains failed") + } + if lr.netfilterMode != netfilterOff { + t.Errorf("netfilterMode = %v; want netfilterOff after failed AddChains", lr.netfilterMode) + } + if nfr.addConnmarkSaveCalls != 0 { + t.Errorf("AddConnmarkSaveRule called %d times; want 0 when chain setup failed", + nfr.addConnmarkSaveCalls) + } + if nfr.addExternalCGNATCalls != 0 { + t.Errorf("AddExternalCGNATRules called %d times; want 0 when chain setup failed", + nfr.addExternalCGNATCalls) + } +} diff --git a/wgengine/router/osrouter/runner.go b/wgengine/router/osrouter/runner.go index bdc710a8d369a..82b2680e67277 100644 --- a/wgengine/router/osrouter/runner.go +++ b/wgengine/router/osrouter/runner.go @@ -10,6 +10,7 @@ import ( "fmt" "os" "os/exec" + "slices" "strconv" "strings" "syscall" @@ -42,8 +43,7 @@ func errCode(err error) int { if err == nil { return 0 } - var e *exec.ExitError - if ok := errors.As(err, &e); ok { + if e, ok := errors.AsType[*exec.ExitError](err); ok { return e.ExitCode() } s := err.Error() @@ -96,12 +96,7 @@ func newRunGroup(okCode []int, runner commandRunner) *runGroup { func (rg *runGroup) okCode(err error) bool { got := errCode(err) - for _, want := range rg.OkCode { - if got == want { - return true - } - } - return false + return slices.Contains(rg.OkCode, got) } func (rg *runGroup) Output(args ...string) []byte { diff --git a/wgengine/router/router.go b/wgengine/router/router.go index 6868acb43ee2b..f8d702d470527 100644 --- a/wgengine/router/router.go +++ b/wgengine/router/router.go @@ -132,10 +132,11 @@ type Config struct { SubnetRoutes []netip.Prefix // Linux-only things below, ignored on other platforms. - SNATSubnetRoutes bool // SNAT traffic to local subnets - StatefulFiltering bool // Apply stateful filtering to inbound connections - NetfilterMode preftype.NetfilterMode // how much to manage netfilter rules - NetfilterKind string // what kind of netfilter to use ("nftables", "iptables", or "" to auto-detect) + SNATSubnetRoutes bool // SNAT traffic to local subnets + StatefulFiltering bool // Apply stateful filtering to inbound connections + NetfilterMode preftype.NetfilterMode // how much to manage netfilter rules + NetfilterKind string // what kind of netfilter to use ("nftables", "iptables", or "" to auto-detect) + RemoveCGNATDropRule bool // whether to remove the firewall rule to drop non-Tailscale inbound traffic from CGNAT IPs } func (a *Config) Equal(b *Config) bool { diff --git a/wgengine/router/router_test.go b/wgengine/router/router_test.go index 28750e115a9e3..e6b41558622fe 100644 --- a/wgengine/router/router_test.go +++ b/wgengine/router/router_test.go @@ -15,12 +15,12 @@ func TestConfigEqual(t *testing.T) { testedFields := []string{ "LocalAddrs", "Routes", "LocalRoutes", "NewMTU", "SubnetRoutes", "SNATSubnetRoutes", "StatefulFiltering", - "NetfilterMode", "NetfilterKind", + "NetfilterMode", "NetfilterKind", "RemoveCGNATDropRule", } configType := reflect.TypeFor[Config]() configFields := []string{} - for i := range configType.NumField() { - configFields = append(configFields, configType.Field(i).Name) + for field := range configType.Fields() { + configFields = append(configFields, field.Name) } if !reflect.DeepEqual(configFields, testedFields) { t.Errorf("Config.Equal check might be out of sync\nfields: %q\nhandled: %q\n", diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 245ce421fbe5a..23edf30b379de 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -4,22 +4,21 @@ package wgengine import ( - "bufio" "context" crand "crypto/rand" + "crypto/x509" "errors" "fmt" "io" - "maps" "math" "net/netip" - "reflect" "runtime" "slices" - "strings" "sync" + "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "tailscale.com/control/controlknobs" @@ -42,6 +41,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tstime/mono" "tailscale.com/types/dnstype" + "tailscale.com/types/events" "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -68,29 +68,6 @@ import ( "tailscale.com/wgengine/wglog" ) -// Lazy wireguard-go configuration parameters. -const ( - // lazyPeerIdleThreshold is the idle duration after - // which we remove a peer from the wireguard configuration. - // (This includes peers that have never been idle, which - // effectively have infinite idleness) - lazyPeerIdleThreshold = 5 * time.Minute - - // packetSendTimeUpdateFrequency controls how often we record - // the time that we wrote a packet to an IP address. - packetSendTimeUpdateFrequency = 10 * time.Second - - // packetSendRecheckWireguardThreshold controls how long we can go - // between packet sends to an IP before checking to see - // whether this IP address needs to be added back to the - // WireGuard peer oconfig. - packetSendRecheckWireguardThreshold = 1 * time.Minute -) - -// statusPollInterval is how often we ask wireguard-go for its engine -// status (as long as there's activity). See docs on its use below. -const statusPollInterval = 1 * time.Minute - // networkLoggerUploadTimeout is the maximum timeout to wait when // shutting down the network logger as it uploads the last network log messages. const networkLoggerUploadTimeout = 5 * time.Second @@ -120,7 +97,8 @@ type userspaceEngine struct { birdClient BIRDClient // or nil controlKnobs *controlknobs.Knobs // or nil - testMaybeReconfigHook func() // for tests; if non-nil, fires if maybeReconfigWireguardLocked called + testMaybeReconfigHook func() // for tests; if non-nil, fires if maybeReconfigWireguardLocked called + testDiscoChangedHook func(map[key.NodePublic]bool) // for tests; if non-nil, fires after assembling discoChanged map // isLocalAddr reports the whether an IP is assigned to the local // tunnel interface. It's used to reflect local packets @@ -131,21 +109,27 @@ type userspaceEngine struct { // is being routed over Tailscale. isDNSIPOverTailscale syncs.AtomicValue[func(netip.Addr) bool] - wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below - lastCfgFull wgcfg.Config - lastNMinPeers int - lastRouter *router.Config - lastEngineFull *wgcfg.Config // of full wireguard config, not trimmed - lastEngineInputs *maybeReconfigInputs - lastDNSConfig dns.ConfigView // or invalid if none - lastIsSubnetRouter bool // was the node a primary subnet router in the last run. - recvActivityAt map[key.NodePublic]mono.Time - trimmedNodes map[key.NodePublic]bool // set of node keys of peers currently excluded from wireguard config - sentActivityAt map[netip.Addr]*mono.Time // value is accessed atomically - destIPActivityFuncs map[netip.Addr]func() - lastStatusPollTime mono.Time // last time we polled the engine status - reconfigureVPN func() error // or nil - conn25PacketHooks Conn25PacketHooks // or nil + wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below + + // peerByIPRoute is a longest-prefix-match table built from + // lastCfgFull.Peers AllowedIPs. It's the slow path for + // SetPeerByIPPacketFunc, used when LocalBackend's exact-IP fast path + // (nodeByAddr) misses — i.e. for subnet routes and exit-node default + // routes. Built from lastCfgFull (the wireguard-filtered peer list) + // rather than the netmap so that exit-node selection is honored: the + // netmap has 0.0.0.0/0 in AllowedIPs for every exit-capable peer, but + // lastCfgFull only has it for the currently-selected exit node. + // + // Replaced (not mutated) by maybeReconfigWireguardLocked. Read by + // the per-packet wgdev callback without locking. + peerByIPRoute atomic.Pointer[bart.Table[key.NodePublic]] + + lastCfgFull wgcfg.Config + lastRouter *router.Config + lastDNSConfig dns.ConfigView // or invalid if none + lastIsSubnetRouter bool // was the node a primary subnet router in the last run. + reconfigureVPN func() error // or nil + conn25PacketHooks Conn25PacketHooks // or nil mu sync.Mutex // guards following; see lock order comment below netMap *netmap.NetworkMap // or nil @@ -166,6 +150,10 @@ type userspaceEngine struct { // networkLogger logs statistics about network connections. networkLogger netlog.Logger + // tsmpLearnedDisco tracks per node key if a peer disco key was learned via TSMP. + // wgLock must be held when using this map. + tsmpLearnedDisco map[key.NodePublic]key.DiscoPublic + // Lock ordering: magicsock.Conn.mu, wgLock, then mu. } @@ -230,6 +218,10 @@ type Config struct { // If nil, a new Dialer is created. Dialer *tsdial.Dialer + // ExtraRootCAs, if non-nil, specifies additional trusted root CAs for TLS + // connections (e.g. DERP). Passed through to magicsock. + ExtraRootCAs *x509.CertPool + // ControlKnobs is the set of control plane-provied knobs // to use. // If nil, defaults are used. @@ -265,6 +257,20 @@ type Config struct { // Conn25PacketHooks, if non-nil, is used to hook packets for Connectors 2025 // app connector handling logic. Conn25PacketHooks Conn25PacketHooks + + // ForceDiscoKey, if non-zero, forces the use of a specific disco + // private key. This should only be used for special cases and + // experiments, not for production. The recommended normal path is to + // leave it zero, in which case a new disco key is generated per + // Tailscale start and kept only in memory. + ForceDiscoKey key.DiscoPrivate + + // OnDERPRecv, if non-nil, is called for every non-disco packet + // received from DERP before the peer map lookup. If it returns + // true, the packet is considered handled and is not passed to + // WireGuard. The pkt slice is borrowed and must be copied if + // the callee needs to retain it. + OnDERPRecv func(regionID int, src key.NodePublic, pkt []byte) (handled bool) } // NewFakeUserspaceEngine returns a new userspace engine for testing. @@ -430,14 +436,13 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) IdleFunc: e.tundev.IdleDuration, NetMon: e.netMon, HealthTracker: e.health, + ExtraRootCAs: conf.ExtraRootCAs, Metrics: conf.Metrics, ControlKnobs: conf.ControlKnobs, PeerByKeyFunc: e.PeerByKey, + ForceDiscoKey: conf.ForceDiscoKey, + OnDERPRecv: conf.OnDERPRecv, } - if buildfeatures.HasLazyWG { - magicsockOpts.NoteRecvActivity = e.noteRecvActivity - } - var err error e.magicConn, err = magicsock.NewConn(magicsockOpts) if err != nil { @@ -505,6 +510,16 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) e.logf("Creating WireGuard device...") e.wgdev = wgcfg.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger) closePool.addFunc(e.wgdev.Close) + + // Install a default outbound-packet peer lookup callback. It uses only + // the engine's BART table, which is rebuilt from the wireguard-filtered + // peer list on every Reconfig. Consumers (e.g. LocalBackend) may later + // call SetPeerByIPPacketFunc to additionally install a fast path for + // exact node-address matches; the BART remains the slow-path fallback. + // Without this default, callers that don't run a LocalBackend would + // have no way to route outbound packets to peers, since peers are + // created lazily from inbound packets only via SetPeerLookupFunc. + e.SetPeerByIPPacketFunc(nil) closePool.addFunc(func() { if err := e.magicConn.Close(); err != nil { e.logf("error closing magicconn: %v", err) @@ -581,7 +596,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) } e.linkChangeQueue.Add(func() { e.linkChange(&cd) }) }) - eventbus.SubscribeFunc(ec, func(update tstun.DiscoKeyAdvertisement) { + eventbus.SubscribeFunc(ec, func(update events.PeerDiscoKeyUpdate) { e.logf("wgengine: got TSMP disco key advertisement from %v via eventbus", update.Src) if e.magicConn == nil { e.logf("wgengine: no magicConn") @@ -664,135 +679,11 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) return filter.Accept } -var debugTrimWireguard = envknob.RegisterOptBool("TS_DEBUG_TRIM_WIREGUARD") - -// forceFullWireguardConfig reports whether we should give wireguard our full -// network map, even for inactive peers. -// -// TODO(bradfitz): remove this at some point. We had a TODO to do it before 1.0 -// but it's still there as of 1.30. Really we should not do this wireguard lazy -// peer config at all and just fix wireguard-go to not have so much extra memory -// usage per peer. That would simplify a lot of Tailscale code. OTOH, we have 50 -// MB of memory on iOS now instead of 15 MB, so the other option is to just give -// up on lazy wireguard config and blow the memory and hope for the best on iOS. -// That's sad too. Or we get rid of these knobs (lazy wireguard config has been -// stable!) but I'm worried that a future regression would be easier to debug -// with these knobs in place. -func (e *userspaceEngine) forceFullWireguardConfig(numPeers int) bool { - // Did the user explicitly enable trimming via the environment variable knob? - if b, ok := debugTrimWireguard().Get(); ok { - return !b - } - return e.controlKnobs != nil && e.controlKnobs.KeepFullWGConfig.Load() -} - -// isTrimmablePeer reports whether p is a peer that we can trim out of the -// network map. -// -// For implementation simplicity, we can only trim peers that have -// only non-subnet AllowedIPs (an IPv4 /32 or IPv6 /128), which is the -// common case for most peers. Subnet router nodes will just always be -// created in the wireguard-go config. -func (e *userspaceEngine) isTrimmablePeer(p *wgcfg.Peer, numPeers int) bool { - if e.forceFullWireguardConfig(numPeers) { - return false - } - - // AllowedIPs must all be single IPs, not subnets. - for _, aip := range p.AllowedIPs { - if !aip.IsSingleIP() { - return false - } - } - return true -} - -// noteRecvActivity is called by magicsock when a packet has been -// received for the peer with node key nk. Magicsock calls this no -// more than every 10 seconds for a given peer. -func (e *userspaceEngine) noteRecvActivity(nk key.NodePublic) { - e.wgLock.Lock() - defer e.wgLock.Unlock() - - if _, ok := e.recvActivityAt[nk]; !ok { - // Not a trimmable peer we care about tracking. (See isTrimmablePeer) - if e.trimmedNodes[nk] { - e.logf("wgengine: [unexpected] noteReceiveActivity called on idle node %v that's not in recvActivityAt", nk.ShortString()) - } - return - } - now := e.timeNow() - e.recvActivityAt[nk] = now - - // As long as there's activity, periodically poll the engine to get - // stats for the far away side effect of - // ipn/ipnlocal.LocalBackend.parseWgStatusLocked to log activity, for - // use in various admin dashboards. - // This particularly matters on platforms without a connected GUI, as - // the GUIs generally poll this enough to cause that logging. But - // tailscaled alone did not, hence this. - if e.lastStatusPollTime.IsZero() || now.Sub(e.lastStatusPollTime) >= statusPollInterval { - e.lastStatusPollTime = now - go e.RequestStatus() - } - - // If the last activity time jumped a bunch (say, at least - // half the idle timeout) then see if we need to reprogram - // WireGuard. This could probably be just - // lazyPeerIdleThreshold without the divide by 2, but - // maybeReconfigWireguardLocked is cheap enough to call every - // couple minutes (just not on every packet). - if e.trimmedNodes[nk] { - e.logf("wgengine: idle peer %v now active, reconfiguring WireGuard", nk.ShortString()) - e.maybeReconfigWireguardLocked(nil) - } -} - -// isActiveSinceLocked reports whether the peer identified by (nk, ip) -// has had a packet sent to or received from it since t. +// maybeReconfigWireguardLocked reconfigures wireguard-go with the current +// full config, installing a PeerLookupFunc for on-demand peer creation. // // e.wgLock must be held. -func (e *userspaceEngine) isActiveSinceLocked(nk key.NodePublic, ip netip.Addr, t mono.Time) bool { - if e.recvActivityAt[nk].After(t) { - return true - } - timePtr, ok := e.sentActivityAt[ip] - if !ok { - return false - } - return timePtr.LoadAtomic().After(t) -} - -// maybeReconfigInputs holds the inputs to the maybeReconfigWireguardLocked -// function. If these things don't change between calls, there's nothing to do. -type maybeReconfigInputs struct { - WGConfig *wgcfg.Config - TrimmedNodes map[key.NodePublic]bool - TrackNodes views.Slice[key.NodePublic] - TrackIPs views.Slice[netip.Addr] -} - -func (i *maybeReconfigInputs) Equal(o *maybeReconfigInputs) bool { - return reflect.DeepEqual(i, o) -} - -func (i *maybeReconfigInputs) Clone() *maybeReconfigInputs { - if i == nil { - return nil - } - v := *i - v.WGConfig = i.WGConfig.Clone() - v.TrimmedNodes = maps.Clone(i.TrimmedNodes) - return &v -} - -// discoChanged are the set of peers whose disco keys have changed, implying they've restarted. -// If a peer is in this set and was previously in the live wireguard config, -// it needs to be first removed and then re-added to flush out its wireguard session key. -// If discoChanged is nil or empty, this extra removal step isn't done. -// -// e.wgLock must be held. -func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.NodePublic]bool) error { +func (e *userspaceEngine) maybeReconfigWireguardLocked() error { if hook := e.testMaybeReconfigHook; hook != nil { hook() return nil @@ -801,177 +692,49 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node full := e.lastCfgFull e.wgLogger.SetPeers(full.Peers) - // Compute a minimal config to pass to wireguard-go - // based on the full config. Prune off all the peers - // and only add the active ones back. - min := full - min.Peers = make([]wgcfg.Peer, 0, e.lastNMinPeers) - - // We'll only keep a peer around if it's been active in - // the past 5 minutes. That's more than WireGuard's key - // rotation time anyway so it's no harm if we remove it - // later if it's been inactive. - var activeCutoff mono.Time - if buildfeatures.HasLazyWG { - activeCutoff = e.timeNow().Add(-lazyPeerIdleThreshold) - } - - // Not all peers can be trimmed from the network map (see - // isTrimmablePeer). For those that are trimmable, keep track of - // their NodeKey and Tailscale IPs. These are the ones we'll need - // to install tracking hooks for to watch their send/receive - // activity. - var trackNodes []key.NodePublic - var trackIPs []netip.Addr - if buildfeatures.HasLazyWG { - trackNodes = make([]key.NodePublic, 0, len(full.Peers)) - trackIPs = make([]netip.Addr, 0, len(full.Peers)) - } - - // Don't re-alloc the map; the Go compiler optimizes map clears as of - // Go 1.11, so we can re-use the existing + allocated map. - if e.trimmedNodes != nil { - clear(e.trimmedNodes) - } else { - e.trimmedNodes = make(map[key.NodePublic]bool) - } - - needRemoveStep := false - for i := range full.Peers { - p := &full.Peers[i] - nk := p.PublicKey - if !buildfeatures.HasLazyWG || !e.isTrimmablePeer(p, len(full.Peers)) { - min.Peers = append(min.Peers, *p) - if discoChanged[nk] { - needRemoveStep = true - } - continue - } - trackNodes = append(trackNodes, nk) - recentlyActive := false - for _, cidr := range p.AllowedIPs { - trackIPs = append(trackIPs, cidr.Addr()) - recentlyActive = recentlyActive || e.isActiveSinceLocked(nk, cidr.Addr(), activeCutoff) - } - if recentlyActive { - min.Peers = append(min.Peers, *p) - if discoChanged[nk] { - needRemoveStep = true - } - } else { - e.trimmedNodes[nk] = true - } - } - e.lastNMinPeers = len(min.Peers) - - if changed := checkchange.Update(&e.lastEngineInputs, &maybeReconfigInputs{ - WGConfig: &min, - TrimmedNodes: e.trimmedNodes, - TrackNodes: views.SliceOf(trackNodes), - TrackIPs: views.SliceOf(trackIPs), - }); !changed { - return nil - } - - if buildfeatures.HasLazyWG { - e.updateActivityMapsLocked(trackNodes, trackIPs) - } - - if needRemoveStep { - minner := min - minner.Peers = nil - numRemove := 0 - for _, p := range min.Peers { - if discoChanged[p.PublicKey] { - numRemove++ - continue - } - minner.Peers = append(minner.Peers, p) - } - if numRemove > 0 { - e.logf("wgengine: Reconfig: removing session keys for %d peers", numRemove) - if err := wgcfg.ReconfigDevice(e.wgdev, &minner, e.logf); err != nil { - e.logf("wgdev.Reconfig: %v", err) - return err - } + // Rebuild the prefix-match peer routing table from the current + // (wireguard-filtered) peer list and publish it atomically. + rt := &bart.Table[key.NodePublic]{} + for _, p := range full.Peers { + for _, pfx := range p.AllowedIPs { + rt.Insert(pfx, p.PublicKey) } } + e.peerByIPRoute.Store(rt) - e.logf("wgengine: Reconfig: configuring userspace WireGuard config (with %d/%d peers)", len(min.Peers), len(full.Peers)) - if err := wgcfg.ReconfigDevice(e.wgdev, &min, e.logf); err != nil { + e.logf("wgengine: Reconfig: configuring userspace WireGuard config (with %d peers)", len(full.Peers)) + if err := wgcfg.ReconfigDevice(e.wgdev, &full, e.logf); err != nil { e.logf("wgdev.Reconfig: %v", err) return err } return nil } -// updateActivityMapsLocked updates the data structures used for tracking the activity -// of wireguard peers that we might add/remove dynamically from the real config -// as given to wireguard-go. +// SetPeerByIPPacketFunc installs a callback used by wireguard-go to look up +// which peer should handle an outbound packet by destination IP. // -// e.wgLock must be held. -func (e *userspaceEngine) updateActivityMapsLocked(trackNodes []key.NodePublic, trackIPs []netip.Addr) { - if !buildfeatures.HasLazyWG { - return - } - // Generate the new map of which nodekeys we want to track - // receive times for. - mr := map[key.NodePublic]mono.Time{} // TODO: only recreate this if set of keys changed - for _, nk := range trackNodes { - // Preserve old times in the new map, but also - // populate map entries for new trackNodes values with - // time.Time{} zero values. (Only entries in this map - // are tracked, so the Time zero values allow it to be - // tracked later) - mr[nk] = e.recvActivityAt[nk] - } - e.recvActivityAt = mr - - oldTime := e.sentActivityAt - e.sentActivityAt = make(map[netip.Addr]*mono.Time, len(oldTime)) - oldFunc := e.destIPActivityFuncs - e.destIPActivityFuncs = make(map[netip.Addr]func(), len(oldFunc)) - - updateFn := func(timePtr *mono.Time) func() { - return func() { - now := e.timeNow() - old := timePtr.LoadAtomic() - - // How long's it been since we last sent a packet? - elapsed := now.Sub(old) - if old == 0 { - // For our first packet, old is 0, which has indeterminate meaning. - // Set elapsed to a big number (four score and seven years). - elapsed = 762642 * time.Hour - } - - if elapsed >= packetSendTimeUpdateFrequency { - timePtr.StoreAtomic(now) - } - // On a big jump, assume we might no longer be in the wireguard - // config and go check. - if elapsed >= packetSendRecheckWireguardThreshold { - e.wgLock.Lock() - defer e.wgLock.Unlock() - e.maybeReconfigWireguardLocked(nil) +// fn is an optional fast path for exact node-address matches (e.g. dst is a +// Tailscale IP). On miss (or if fn is nil), the engine's own BART table +// ([userspaceEngine.peerByIPRoute], built from the wireguard-filtered peer +// list) is consulted to handle subnet routes and exit-node default routes. +// +// [NewUserspaceEngine] installs a BART-only default at engine creation time, +// so callers that don't call SetPeerByIPPacketFunc (e.g. those not running +// a LocalBackend) still get working outbound packet routing. +func (e *userspaceEngine) SetPeerByIPPacketFunc(fn func(netip.Addr) (_ key.NodePublic, ok bool)) { + e.wgdev.SetPeerByIPPacketFunc(func(_, dst netip.Addr, _ []byte) (device.NoisePublicKey, bool) { + if fn != nil { + if pk, ok := fn(dst); ok { + return pk.Raw32(), true } } - } - - for _, ip := range trackIPs { - timePtr := oldTime[ip] - if timePtr == nil { - timePtr = new(mono.Time) - } - e.sentActivityAt[ip] = timePtr - - fn := oldFunc[ip] - if fn == nil { - fn = updateFn(timePtr) + if rt := e.peerByIPRoute.Load(); rt != nil { + if pk, ok := rt.Lookup(dst); ok { + return pk.Raw32(), true + } } - e.destIPActivityFuncs[ip] = fn - } - e.tundev.SetDestIPActivityFuncs(e.destIPActivityFuncs) + return device.NoisePublicKey{}, false + }) } // hasOverlap checks if there is a IPPrefix which is common amongst the two @@ -1011,6 +774,12 @@ func (e *userspaceEngine) ResetAndStop() (*Status, error) { } } +func (e *userspaceEngine) PatchDiscoKey(pub key.NodePublic, disco key.DiscoPublic) { + e.wgLock.Lock() + defer e.wgLock.Unlock() + mak.Set(&e.tsmpLearnedDisco, pub, disco) +} + func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { if routerCfg == nil { panic("routerCfg must not be nil") @@ -1054,7 +823,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, } isSubnetRouterChanged := buildfeatures.HasAdvertiseRoutes && isSubnetRouter != e.lastIsSubnetRouter - engineChanged := checkchange.Update(&e.lastEngineFull, cfg) + engineChanged := !e.lastCfgFull.Equal(cfg) routerChanged := checkchange.Update(&e.lastRouter, routerCfg) dnsChanged := buildfeatures.HasDNS && !e.lastDNSConfig.Equal(dnsCfg.View()) if dnsChanged { @@ -1086,11 +855,10 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, } // See if any peers have changed disco keys, which means they've restarted. - // If so, we need to update the wireguard-go/device.Device in two phases: - // once without the node which has restarted, to clear its wireguard session key, - // and a second time with it. + // If so, remove the peer from wireguard-go to flush its session key, + // then let the PeerLookupFunc re-create it on demand. discoChanged := make(map[key.NodePublic]bool) - { + if engineChanged { prevEP := make(map[key.NodePublic]key.DiscoPublic) for i := range e.lastCfgFull.Peers { if p := &e.lastCfgFull.Peers[i]; !p.DiscoKey.IsZero() { @@ -1102,29 +870,84 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, if p.DiscoKey.IsZero() { continue } + pub := p.PublicKey + if old, ok := prevEP[pub]; ok && old != p.DiscoKey { + // If the disco key was learned via TSMP, we do not need to reset the + // wireguard config as the new key was received over an existing wireguard + // connection. + if discoTSMP, okTSMP := e.tsmpLearnedDisco[p.PublicKey]; okTSMP { + // Key matches, remove entry from map. + delete(e.tsmpLearnedDisco, p.PublicKey) + if discoTSMP == p.DiscoKey { + e.logf("wgengine: Skipping reconfig (TSMP key): %s changed from %q to %q", + pub.ShortString(), old, p.DiscoKey) + // Skip session clear. + continue + } + + // The new disco key does not match what we received via + // TSMP for this peer. This is unexpected, though possible + // if processing a change in a large netmap ends up taking + // longer than the 2 second timeout in + // [controlClient.mapRoutineState.UpdateNetmapDelta], or if + // the context is cancelled mid update. Log the event, and reset + // the connection as it is possibly a stale entry in the map + // instead of a TSMP disco key update that led us here. + e.logf("wgengine: [unexpected] Reconfig: using TSMP key for %s (control stale): tsmp=%q control=%q old=%q", + pub.ShortString(), discoTSMP, p.DiscoKey, old) + metricTSMPLearnedKeyMismatch.Add(1) + } + discoChanged[pub] = true e.logf("wgengine: Reconfig: %s changed from %q to %q", pub.ShortString(), old, p.DiscoKey) } } } - e.lastCfgFull = *cfg.Clone() + // For tests, what disco connections needs to be changed. + if e.testDiscoChangedHook != nil { + e.testDiscoChangedHook(discoChanged) + } + + if !e.lastCfgFull.PrivateKey.Equal(cfg.PrivateKey) { + // Tell magicsock about the new (or initial) private key + // (which is needed by DERP) before wgdev gets it, as wgdev + // will start trying to handshake, which we want to be able to + // go over DERP. + if err := e.magicConn.SetPrivateKey(cfg.PrivateKey); err != nil { + e.logf("wgengine: Reconfig: SetPrivateKey: %v", err) + } - // Tell magicsock about the new (or initial) private key - // (which is needed by DERP) before wgdev gets it, as wgdev - // will start trying to handshake, which we want to be able to - // go over DERP. - if err := e.magicConn.SetPrivateKey(cfg.PrivateKey); err != nil { - e.logf("wgengine: Reconfig: SetPrivateKey: %v", err) + if err := e.wgdev.SetPrivateKey(key.NodePrivateAs[device.NoisePrivateKey](cfg.PrivateKey)); err != nil { + e.logf("wgengine: Reconfig: wgdev.SetPrivateKey: %v", err) + } } + + e.lastCfgFull = *cfg.Clone() + e.magicConn.UpdatePeers(peerSet) e.magicConn.SetPreferredPort(listenPort) e.magicConn.UpdatePMTUD() - if err := e.maybeReconfigWireguardLocked(discoChanged); err != nil { - return err + if engineChanged { + if err := e.maybeReconfigWireguardLocked(); err != nil { + return err + } + // Now that we've reconfigured wireguard-go, remove any peers with + // changed disco keys to flush their session keys, and let them be + // re-created on demand by the PeerLookupFunc. + for pub := range discoChanged { + e.wgdev.RemovePeer(pub.Raw32()) + } + } + + // Cleanup map of tsmp marks for peers that no longer exists in config. + for nodeKey := range e.tsmpLearnedDisco { + if !peerSet.Contains(nodeKey) { + delete(e.tsmpLearnedDisco, nodeKey) + } } // Shutdown the network logger because the IDs changed. @@ -1262,8 +1085,14 @@ func (e *userspaceEngine) PeerByKey(pubKey key.NodePublic) (_ wgint.Peer, ok boo if dev == nil { return wgint.Peer{}, false } - peer := dev.LookupPeer(pubKey.Raw32()) - if peer == nil { + // Use LookupActivePeer (not LookupPeer) to avoid triggering on-demand + // peer creation via PeerLookupFunc. PeerByKey is called from status + // polling paths (getStatus, getPeerStatusLite) which iterate every peer + // in the netmap; using LookupPeer would lazily create a wireguard-go + // peer for every single netmap peer on each status poll, leaking + // memory via per-peer queues and goroutines. + peer, ok := dev.LookupActivePeer(pubKey.Raw32()) + if !ok { return wgint.Peer{}, false } return wgint.PeerOf(peer), true @@ -1359,8 +1188,6 @@ func (e *userspaceEngine) Close() { e.closing = true e.mu.Unlock() - r := bufio.NewReader(strings.NewReader("")) - e.wgdev.IpcSetOperation(r) e.magicConn.Close() if e.netMonOwned { e.netMon.Close() @@ -1823,6 +1650,8 @@ var ( metricTSMPDiscoKeyAdvertisementSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_sent") metricTSMPDiscoKeyAdvertisementError = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_error") + + metricTSMPLearnedKeyMismatch = clientmetric.NewCounter("magicsock_tsmp_learned_key_mismatch") ) func (e *userspaceEngine) InstallCaptureHook(cb packet.CaptureCallback) { diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go index b06ea527b27ba..b2f40fada379f 100644 --- a/wgengine/userspace_test.go +++ b/wgengine/userspace_test.go @@ -5,10 +5,12 @@ package wgengine import ( "fmt" + "math/rand" "net/netip" "os" - "reflect" "runtime" + "slices" + "sync" "testing" "go4.org/mem" @@ -17,81 +19,22 @@ import ( "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/net/dns" + "tailscale.com/net/dns/resolver" "tailscale.com/net/netaddr" - "tailscale.com/net/tstun" + "tailscale.com/net/netmon" "tailscale.com/tailcfg" - "tailscale.com/tstest" - "tailscale.com/tstime/mono" + "tailscale.com/types/dnstype" "tailscale.com/types/key" + "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/opt" + "tailscale.com/util/dnsname" "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/usermetric" "tailscale.com/wgengine/router" "tailscale.com/wgengine/wgcfg" ) -func TestNoteReceiveActivity(t *testing.T) { - now := mono.Time(123456) - var logBuf tstest.MemLogger - - confc := make(chan bool, 1) - gotConf := func() bool { - select { - case <-confc: - return true - default: - return false - } - } - e := &userspaceEngine{ - timeNow: func() mono.Time { return now }, - recvActivityAt: map[key.NodePublic]mono.Time{}, - logf: logBuf.Logf, - tundev: new(tstun.Wrapper), - testMaybeReconfigHook: func() { confc <- true }, - trimmedNodes: map[key.NodePublic]bool{}, - } - ra := e.recvActivityAt - - nk := key.NewNode().Public() - - // Activity on an untracked key should do nothing. - e.noteRecvActivity(nk) - if len(ra) != 0 { - t.Fatalf("unexpected growth in map: now has %d keys; want 0", len(ra)) - } - if logBuf.Len() != 0 { - t.Fatalf("unexpected log write (and thus activity): %s", logBuf.Bytes()) - } - - // Now track it, but don't mark it trimmed, so shouldn't update. - ra[nk] = 0 - e.noteRecvActivity(nk) - if len(ra) != 1 { - t.Fatalf("unexpected growth in map: now has %d keys; want 1", len(ra)) - } - if got := ra[nk]; got != now { - t.Fatalf("time in map = %v; want %v", got, now) - } - if gotConf() { - t.Fatalf("unexpected reconfig") - } - - // Now mark it trimmed and expect an update. - e.trimmedNodes[nk] = true - e.noteRecvActivity(nk) - if len(ra) != 1 { - t.Fatalf("unexpected growth in map: now has %d keys; want 1", len(ra)) - } - if got := ra[nk]; got != now { - t.Fatalf("time in map = %v; want %v", got, now) - } - if !gotConf() { - t.Fatalf("didn't get expected reconfig") - } -} - func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { nv := make([]tailcfg.NodeView, len(v)) for i, n := range v { @@ -110,7 +53,6 @@ func TestUserspaceEngineReconfig(t *testing.T) { t.Fatal(err) } t.Cleanup(e.Close) - ue := e.(*userspaceEngine) routerCfg := &router.Config{} @@ -146,19 +88,166 @@ func TestUserspaceEngineReconfig(t *testing.T) { if err != nil { t.Fatal(err) } + } +} + +func TestUserspaceEngineTSMPLearned(t *testing.T) { + bus := eventbustest.NewBus(t) + + ht := health.NewTracker(bus) + reg := new(usermetric.Registry) + e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg, bus) + if err != nil { + t.Fatal(err) + } + t.Cleanup(e.Close) + ue := e.(*userspaceEngine) + + discoChangedChan := make(chan map[key.NodePublic]bool, 1) + ue.testDiscoChangedHook = func(m map[key.NodePublic]bool) { + discoChangedChan <- m + } + + routerCfg := &router.Config{} - wantRecvAt := map[key.NodePublic]mono.Time{ - nkFromHex(nodeHex): 0, + keyChanges := []struct { + tsmp bool + inMap bool + }{ + {tsmp: false, inMap: false}, + {tsmp: true, inMap: false}, + {tsmp: false, inMap: true}, + } + + nkHex := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + for _, change := range keyChanges { + oldDisco := key.NewDisco() + nm := &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 1, + Key: nkFromHex(nkHex), + DiscoKey: oldDisco.Public(), + }, + }), } - if got := ue.recvActivityAt; !reflect.DeepEqual(got, wantRecvAt) { - t.Errorf("wrong recvActivityAt\n got: %v\nwant: %v\n", got, wantRecvAt) + nk, err := key.ParseNodePublicUntyped(mem.S(nkHex)) + if err != nil { + t.Fatal(err) } + e.SetNetworkMap(nm) - wantTrimmedNodes := map[key.NodePublic]bool{ - nkFromHex(nodeHex): true, + newDisco := key.NewDisco() + cfg := &wgcfg.Config{ + Peers: []wgcfg.Peer{ + { + PublicKey: nk, + DiscoKey: newDisco.Public(), + }, + }, } - if got := ue.trimmedNodes; !reflect.DeepEqual(got, wantTrimmedNodes) { - t.Errorf("wrong wantTrimmedNodes\n got: %v\nwant: %v\n", got, wantTrimmedNodes) + + if change.tsmp { + ue.PatchDiscoKey(nk, newDisco.Public()) + } + err = e.Reconfig(cfg, routerCfg, &dns.Config{}) + if err != nil { + t.Fatal(err) + } + + changeMap := <-discoChangedChan + + if _, ok := changeMap[nk]; ok != change.inMap { + t.Fatalf("expect key %v in map %v to be %t, got %t", nk, changeMap, + change.inMap, ok) + } + } +} + +func TestUserspaceEngineTSMPLearnedMismatch(t *testing.T) { + bus := eventbustest.NewBus(t) + + ht := health.NewTracker(bus) + reg := new(usermetric.Registry) + e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg, bus) + if err != nil { + t.Fatal(err) + } + t.Cleanup(e.Close) + ue := e.(*userspaceEngine) + + discoChangedChan := make(chan map[key.NodePublic]bool, 1) + ue.testDiscoChangedHook = func(m map[key.NodePublic]bool) { + discoChangedChan <- m + } + + routerCfg := &router.Config{} + var metricValue int64 = 0 + + keyChanges := []struct { + tsmp bool + inMap bool + wrongKey bool + }{ + {tsmp: false, inMap: false, wrongKey: false}, + {tsmp: true, inMap: false, wrongKey: false}, + {tsmp: true, inMap: true, wrongKey: true}, + {tsmp: false, inMap: true, wrongKey: false}, + } + + nkHex := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + for _, change := range keyChanges { + oldDisco := key.NewDisco() + nm := &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 1, + Key: nkFromHex(nkHex), + DiscoKey: oldDisco.Public(), + }, + }), + } + nk, err := key.ParseNodePublicUntyped(mem.S(nkHex)) + if err != nil { + t.Fatal(err) + } + e.SetNetworkMap(nm) + + newDisco := key.NewDisco() + cfg := &wgcfg.Config{ + Peers: []wgcfg.Peer{ + { + PublicKey: nk, + DiscoKey: newDisco.Public(), + }, + }, + } + + tsmpKey := newDisco.Public() + if change.tsmp { + if change.wrongKey { + tsmpKey = key.NewDisco().Public() + } + ue.PatchDiscoKey(nk, tsmpKey) + } + err = e.Reconfig(cfg, routerCfg, &dns.Config{}) + if err != nil { + t.Fatal(err) + } + + changeMap := <-discoChangedChan + + if _, ok := changeMap[nk]; ok != change.inMap { + t.Fatalf("expect key %v in map %v to be %t, got %t", nk, changeMap, + change.inMap, ok) + } + + metric := metricTSMPLearnedKeyMismatch.Value() + delta := metric - metricValue + metricValue = metric + + if change.wrongKey && delta != 1 { + t.Fatalf("expected a delta of 1, got %d", delta) } } } @@ -175,8 +264,8 @@ func TestUserspaceEnginePortReconfig(t *testing.T) { var ue *userspaceEngine ht := health.NewTracker(bus) reg := new(usermetric.Registry) - for i := range 100 { - attempt := uint16(defaultPort + i) + for range 100 { + attempt := uint16(defaultPort + rand.Intn(1000)) e, err := NewFakeUserspaceEngine(t.Logf, attempt, &knobs, ht, reg, bus) if err != nil { t.Fatal(err) @@ -448,3 +537,76 @@ func BenchmarkGenLocalAddrFunc(b *testing.B) { }) b.Logf("x = %v", x) } + +// Regression test for #19730: on major link change, MatchDomains Routes must +// be preserved. +func TestLinkChangeReapplyPreservesMagicDNSRoutes(t *testing.T) { + switch runtime.GOOS { + case "linux", "android", "darwin", "ios", "openbsd": + default: + t.Skipf("linkChange DNS reapply path not exercised on %s", runtime.GOOS) + } + + bus := eventbustest.NewBus(t) + noop, err := dns.NewNoopManager() + if err != nil { + t.Fatal(err) + } + e, err := NewUserspaceEngine(t.Logf, Config{ + HealthTracker: health.NewTracker(bus), + Metrics: new(usermetric.Registry), + EventBus: bus, + DNS: noop, + RespondToPing: true, + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(e.Close) + + var ( + mu sync.Mutex + last resolver.Config + ) + e.(*userspaceEngine).dns.Resolver().TestOnlySetHook(func(cfg resolver.Config) { + mu.Lock() + defer mu.Unlock() + last = cfg + }) + snapshot := func() []dnsname.FQDN { + mu.Lock() + defer mu.Unlock() + return slices.Clone(last.LocalDomains) + } + + dnsCfg := &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "ts.net.": {{Addr: "199.247.155.53"}}, + "foo.ts.net.": nil, + "64.100.in-addr.arpa.": nil, + }, + Hosts: map[dnsname.FQDN][]netip.Addr{ + "node.foo.ts.net.": {netip.MustParseAddr("100.64.0.5")}, + }, + SearchDomains: []dnsname.FQDN{"foo.ts.net."}, + } + if err := e.Reconfig(&wgcfg.Config{}, &router.Config{}, dnsCfg); err != nil { + t.Fatalf("Reconfig: %v", err) + } + initial := snapshot() + + cd, err := netmon.NewChangeDelta(nil, &netmon.State{HaveV4: true}, 0, true) + if err != nil { + t.Fatal(err) + } + cd.RebindLikelyRequired = true + e.(*userspaceEngine).linkChange(cd) + + after := snapshot() + slices.Sort(initial) + slices.Sort(after) + if !slices.Equal(initial, after) { + t.Errorf("resolver LocalDomains changed after linkChange:\n initial: %s\n after: %s", + logger.AsJSON(initial), logger.AsJSON(after)) + } +} diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index f12b1c19e2764..6aa1c1bd64cf0 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -215,6 +215,10 @@ func (e *watchdogEngine) SetNetworkMap(nm *netmap.NetworkMap) { e.watchdog(SetNetworkMap, func() { e.wrap.SetNetworkMap(nm) }) } +func (e *watchdogEngine) SetPeerByIPPacketFunc(fn func(netip.Addr) (_ key.NodePublic, ok bool)) { + e.wrap.SetPeerByIPPacketFunc(fn) +} + func (e *watchdogEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size int, cb func(*ipnstate.PingResult)) { e.watchdog(Ping, func() { e.wrap.Ping(ip, pingType, size, cb) }) } @@ -242,3 +246,15 @@ func (e *watchdogEngine) InstallCaptureHook(cb packet.CaptureCallback) { func (e *watchdogEngine) PeerByKey(pubKey key.NodePublic) (_ wgint.Peer, ok bool) { return e.wrap.PeerByKey(pubKey) } + +func (e *watchdogEngine) PatchDiscoKey(pub key.NodePublic, disco key.DiscoPublic) { + // PatchDiscoKey mirrors the implementation of [controlclient.patchDiscoKeyer ]. + // It is implemented here to avoid the dependency edge to controlclient, but must be kept + // in sync with the original implementation. + type patchDiscoKeyer interface { + PatchDiscoKey(key.NodePublic, key.DiscoPublic) + } + if n, ok := e.wrap.(patchDiscoKeyer); ok { + n.PatchDiscoKey(pub, disco) + } +} diff --git a/wgengine/watchdog_test.go b/wgengine/watchdog_test.go index 8032339573e90..a0ce9cf079652 100644 --- a/wgengine/watchdog_test.go +++ b/wgengine/watchdog_test.go @@ -26,7 +26,7 @@ func TestWatchdog(t *testing.T) { maxWaitMultiple = 15 } - t.Run("default watchdog does not fire", func(t *testing.T) { + t.Run("default-watchdog-does-not-fire", func(t *testing.T) { t.Parallel() bus := eventbustest.NewBus(t) ht := health.NewTracker(bus) @@ -55,7 +55,7 @@ func TestWatchdogMetrics(t *testing.T) { wantCounts map[watchdogEvent]int64 }{ { - name: "single event types", + name: "single-event-types", events: []watchdogEvent{RequestStatus, PeerForIPEvent, Ping}, wantCounts: map[watchdogEvent]int64{ RequestStatus: 1, @@ -64,7 +64,7 @@ func TestWatchdogMetrics(t *testing.T) { }, }, { - name: "repeated events", + name: "repeated-events", events: []watchdogEvent{RequestStatus, RequestStatus, Ping, RequestStatus}, wantCounts: map[watchdogEvent]int64{ RequestStatus: 3, diff --git a/wgengine/wgcfg/config.go b/wgengine/wgcfg/config.go index 7828121390fba..5510b65b2a199 100644 --- a/wgengine/wgcfg/config.go +++ b/wgengine/wgcfg/config.go @@ -53,11 +53,6 @@ type Peer struct { V6MasqAddr *netip.Addr // if non-nil, masquerade IPv6 traffic to this peer using this address IsJailed bool // if true, this peer is jailed and cannot initiate connections PersistentKeepalive uint16 // in seconds between keep-alives; 0 to disable - // wireguard-go's endpoint for this peer. It should always equal Peer.PublicKey. - // We represent it explicitly so that we can detect if they diverge and recover. - // There is no need to set WGEndpoint explicitly when constructing a Peer by hand. - // It is only populated when reading Peers from wireguard-go. - WGEndpoint key.NodePublic } func addrPtrEq(a, b *netip.Addr) bool { @@ -74,8 +69,7 @@ func (p Peer) Equal(o Peer) bool { p.IsJailed == o.IsJailed && p.PersistentKeepalive == o.PersistentKeepalive && addrPtrEq(p.V4MasqAddr, o.V4MasqAddr) && - addrPtrEq(p.V6MasqAddr, o.V6MasqAddr) && - p.WGEndpoint == o.WGEndpoint + addrPtrEq(p.V6MasqAddr, o.V6MasqAddr) } // PeerWithKey returns the Peer with key k and reports whether it was found. diff --git a/wgengine/wgcfg/config_test.go b/wgengine/wgcfg/config_test.go index b15b8cbf56f8b..013d3a4b49a6e 100644 --- a/wgengine/wgcfg/config_test.go +++ b/wgengine/wgcfg/config_test.go @@ -12,8 +12,7 @@ import ( // that might get added in the future. func TestConfigEqual(t *testing.T) { rt := reflect.TypeFor[Config]() - for i := range rt.NumField() { - sf := rt.Field(i) + for sf := range rt.Fields() { switch sf.Name { case "Name", "NodeID", "PrivateKey", "MTU", "Addresses", "DNS", "Peers", "NetworkLogging": @@ -28,11 +27,10 @@ func TestConfigEqual(t *testing.T) { // that might get added in the future. func TestPeerEqual(t *testing.T) { rt := reflect.TypeFor[Peer]() - for i := range rt.NumField() { - sf := rt.Field(i) + for sf := range rt.Fields() { switch sf.Name { case "PublicKey", "DiscoKey", "AllowedIPs", "IsJailed", - "PersistentKeepalive", "V4MasqAddr", "V6MasqAddr", "WGEndpoint": + "PersistentKeepalive", "V4MasqAddr", "V6MasqAddr": // These are compared in [Peer.Equal]. default: t.Errorf("Have you added field %q to Peer.Equal? Do so if not, and then update TestPeerEqual", sf.Name) diff --git a/wgengine/wgcfg/device.go b/wgengine/wgcfg/device.go index ba29cfbdca8c0..ed32f8337e43e 100644 --- a/wgengine/wgcfg/device.go +++ b/wgengine/wgcfg/device.go @@ -4,9 +4,8 @@ package wgcfg import ( - "errors" - "io" - "sort" + "fmt" + "net/netip" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/device" @@ -21,27 +20,15 @@ func NewDevice(tunDev tun.Device, bind conn.Bind, logger *device.Logger) *device return ret } -func DeviceConfig(d *device.Device) (*Config, error) { - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcGetOperation(w) - w.Close() - }() - cfg, fromErr := FromUAPI(r) - r.Close() - getErr := <-errc - err := errors.Join(getErr, fromErr) - if err != nil { - return nil, err - } - sort.Slice(cfg.Peers, func(i, j int) bool { - return cfg.Peers[i].PublicKey.Less(cfg.Peers[j].PublicKey) - }) - return cfg, nil -} - // ReconfigDevice replaces the existing device configuration with cfg. +// +// Instead of using the UAPI text protocol, it uses the wireguard-go direct API +// to install a [device.PeerLookupFunc] callback that creates peers on demand. +// +// The caller is responsible for: +// - calling [device.Device.SetPrivateKey] when the key changes +// - installing a [device.PeerByIPPacketFunc] on the device for outbound +// packet routing (e.g. via [tailscale.com/wgengine.Engine.SetPeerByIPPacketFunc]) func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) { defer func() { if err != nil { @@ -49,20 +36,52 @@ func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) } }() - prev, err := DeviceConfig(d) - if err != nil { - return err + // Build peer map: public key → allowed IPs. + peers := make(map[device.NoisePublicKey][]netip.Prefix, len(cfg.Peers)) + for _, p := range cfg.Peers { + peers[p.PublicKey.Raw32()] = p.AllowedIPs } - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcSetOperation(r) - r.Close() - }() + // Remove peers not in the new config. + d.RemoveMatchingPeers(func(pk device.NoisePublicKey) bool { + _, exists := peers[pk] + return !exists + }) + + // Update AllowedIPs on any already-active peers whose config may have + // changed. Peers that don't exist yet will get the correct AllowedIPs + // from PeerLookupFunc when they are lazily created. + for pk, allowedIPs := range peers { + if peer, ok := d.LookupActivePeer(pk); ok { + peer.SetAllowedIPs(allowedIPs) + } + } + + // Install callback for lazy peer creation (incoming packets). + bind := d.Bind() + d.SetPeerLookupFunc(func(pubk device.NoisePublicKey) (_ *device.NewPeerConfig, ok bool) { + allowedIPs, ok := peers[pubk] + if !ok { + return nil, false + } + ep, err := bind.ParseEndpoint(fmt.Sprintf("%x", pubk[:])) + if err != nil { + logf("wgcfg: failed to parse endpoint for peer %x: %v", pubk[:8], err) + return nil, false + } + return &device.NewPeerConfig{ + AllowedIPs: allowedIPs, + Endpoint: ep, + }, true + }) + + // RemoveMatchingPeers _again_, now that SetPeerLookupFunc is installed, + // lest any removed peers got re-created before the new SetPeerLookupFunc + // func was installed. + d.RemoveMatchingPeers(func(pk device.NoisePublicKey) bool { + _, exists := peers[pk] + return !exists + }) - toErr := cfg.ToUAPI(logf, w, prev) - w.Close() - setErr := <-errc - return errors.Join(setErr, toErr) + return nil } diff --git a/wgengine/wgcfg/device_test.go b/wgengine/wgcfg/device_test.go index a0443147db80d..07eb41adbbdca 100644 --- a/wgengine/wgcfg/device_test.go +++ b/wgengine/wgcfg/device_test.go @@ -4,33 +4,22 @@ package wgcfg import ( - "bufio" - "bytes" "io" "net/netip" "os" - "sort" - "strings" - "sync" "testing" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" - "go4.org/mem" "tailscale.com/types/key" ) -func TestDeviceConfig(t *testing.T) { - newK := func() (key.NodePublic, key.NodePrivate) { - t.Helper() - k := key.NewNode() - return k.Public(), k - } +func TestReconfigDevice(t *testing.T) { k1, pk1 := newK() ip1 := netip.MustParsePrefix("10.0.0.1/32") - k2, pk2 := newK() + k2, _ := newK() ip2 := netip.MustParsePrefix("10.0.0.2/32") k3, _ := newK() @@ -38,165 +27,80 @@ func TestDeviceConfig(t *testing.T) { cfg1 := &Config{ PrivateKey: pk1, - Peers: []Peer{{ - PublicKey: k2, - AllowedIPs: []netip.Prefix{ip2}, - }}, - } - - cfg2 := &Config{ - PrivateKey: pk2, - Peers: []Peer{{ - PublicKey: k1, - AllowedIPs: []netip.Prefix{ip1}, - PersistentKeepalive: 5, - }}, + Peers: []Peer{ + {PublicKey: k2, AllowedIPs: []netip.Prefix{ip2}}, + }, } - device1 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device1")) - device2 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device2")) - defer device1.Close() - defer device2.Close() + dev := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "test")) + defer dev.Close() - cmp := func(t *testing.T, d *device.Device, want *Config) { - t.Helper() - got, err := DeviceConfig(d) - if err != nil { + t.Run("initial-config", func(t *testing.T) { + if err := ReconfigDevice(dev, cfg1, t.Logf); err != nil { t.Fatal(err) } - prev := new(Config) - gotbuf := new(strings.Builder) - err = got.ToUAPI(t.Logf, gotbuf, prev) - gotStr := gotbuf.String() - if err != nil { - t.Errorf("got.ToUAPI(): error: %v", err) - return - } - wantbuf := new(strings.Builder) - err = want.ToUAPI(t.Logf, wantbuf, prev) - wantStr := wantbuf.String() - if err != nil { - t.Errorf("want.ToUAPI(): error: %v", err) - return - } - if gotStr != wantStr { - buf := new(bytes.Buffer) - w := bufio.NewWriter(buf) - if err := d.IpcGetOperation(w); err != nil { - t.Errorf("on error, could not IpcGetOperation: %v", err) - } - w.Flush() - t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) - } - } - - t.Run("device1 config", func(t *testing.T) { - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device2 config", func(t *testing.T) { - if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device2, cfg2) - }) - - // This is only to test that Config and Reconfig are properly synchronized. - t.Run("device2 config/reconfig", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(2) - - go func() { - ReconfigDevice(device2, cfg2, t.Logf) - wg.Done() - }() - - go func() { - DeviceConfig(device2) - wg.Done() - }() - - wg.Wait() - }) - - t.Run("device1 modify peer", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 1, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) + // Peer should be creatable on demand via LookupPeer. + peer := dev.LookupPeer(k2.Raw32()) + if peer == nil { + t.Fatal("expected peer k2 to exist via LookupPeer") } - cmp(t, device1, cfg1) - }) - - t.Run("device1 replace endpoint", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 2, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) + // Unknown peer should not be found. + peer = dev.LookupPeer(k3.Raw32()) + if peer != nil { + t.Fatal("expected unknown peer k3 to not exist") } - cmp(t, device1, cfg1) }) - t.Run("device1 add new peer", func(t *testing.T) { + t.Run("add-peer", func(t *testing.T) { cfg1.Peers = append(cfg1.Peers, Peer{ PublicKey: k3, AllowedIPs: []netip.Prefix{ip3}, }) - sort.Slice(cfg1.Peers, func(i, j int) bool { - return cfg1.Peers[i].PublicKey.Less(cfg1.Peers[j].PublicKey) - }) - - origCfg, err := DeviceConfig(device1) - if err != nil { + if err := ReconfigDevice(dev, cfg1, t.Logf); err != nil { t.Fatal(err) } - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) + // Both peers should now be discoverable. + if p := dev.LookupPeer(k2.Raw32()); p == nil { + t.Fatal("expected peer k2 to exist") } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) + if p := dev.LookupPeer(k3.Raw32()); p == nil { + t.Fatal("expected peer k3 to exist") } + }) - peer0 := func(cfg *Config) Peer { - p, ok := cfg.PeerWithKey(k2) - if !ok { - t.Helper() - t.Fatal("failed to look up peer 2") - } - return p + t.Run("remove-peer", func(t *testing.T) { + cfg2 := &Config{ + PrivateKey: pk1, + Peers: []Peer{ + {PublicKey: k2, AllowedIPs: []netip.Prefix{ip2}}, + }, } - peersEqual := func(p, q Peer) bool { - return p.PublicKey == q.PublicKey && p.DiscoKey == q.DiscoKey && p.PersistentKeepalive == q.PersistentKeepalive && cidrsEqual(p.AllowedIPs, q.AllowedIPs) + if err := ReconfigDevice(dev, cfg2, t.Logf); err != nil { + t.Fatal(err) + } + // k2 should still be discoverable. + if p := dev.LookupPeer(k2.Raw32()); p == nil { + t.Fatal("expected peer k2 to exist") } - if !peersEqual(peer0(origCfg), peer0(newCfg)) { - t.Error("reconfig modified old peer") + // k3 should no longer be discoverable. + if p := dev.LookupPeer(k3.Raw32()); p != nil { + t.Fatal("expected peer k3 to not exist after removal") } }) - t.Run("device1 remove peer", func(t *testing.T) { - removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey - cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1] - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) + t.Run("self-key-not-peer", func(t *testing.T) { + // The device's own key should not be a peer. + if p := dev.LookupPeer(k1.Raw32()); p != nil { + t.Fatal("expected own key to not be a peer") } - cmp(t, device1, cfg1) + }) - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } + _ = ip1 // suppress unused +} - _, ok := newCfg.PeerWithKey(removeKey) - if ok { - t.Error("reconfig failed to remove peer") - } - }) +func newK() (key.NodePublic, key.NodePrivate) { + k := key.NewNode() + return k.Public(), k } // TODO: replace with a loopback tunnel diff --git a/wgengine/wgcfg/parser.go b/wgengine/wgcfg/parser.go deleted file mode 100644 index 8fb9214091a42..0000000000000 --- a/wgengine/wgcfg/parser.go +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "fmt" - "io" - "net" - "net/netip" - "strconv" - "strings" - - "go4.org/mem" - "tailscale.com/types/key" -) - -type ParseError struct { - why string - offender string -} - -func (e *ParseError) Error() string { - return fmt.Sprintf("%s: %q", e.why, e.offender) -} - -func parseEndpoint(s string) (host string, port uint16, err error) { - i := strings.LastIndexByte(s, ':') - if i < 0 { - return "", 0, &ParseError{"Missing port from endpoint", s} - } - host, portStr := s[:i], s[i+1:] - if len(host) < 1 { - return "", 0, &ParseError{"Invalid endpoint host", host} - } - uport, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return "", 0, err - } - hostColon := strings.IndexByte(host, ':') - if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { - err := &ParseError{"Brackets must contain an IPv6 address", host} - if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { - maybeV6 := net.ParseIP(host[1 : len(host)-1]) - if maybeV6 == nil || len(maybeV6) != net.IPv6len { - return "", 0, err - } - } else { - return "", 0, err - } - host = host[1 : len(host)-1] - } - return host, uint16(uport), nil -} - -// memROCut separates a mem.RO at the separator if it exists, otherwise -// it returns two empty ROs and reports that it was not found. -func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) { - if i := mem.IndexByte(s, sep); i >= 0 { - return s.SliceTo(i), s.SliceFrom(i + 1), true - } - found = false - return -} - -// FromUAPI generates a Config from r. -// r should be generated by calling device.IpcGetOperation; -// it is not compatible with other uapi streams. -func FromUAPI(r io.Reader) (*Config, error) { - cfg := new(Config) - var peer *Peer // current peer being operated on - deviceConfig := true - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := mem.B(scanner.Bytes()) - if line.Len() == 0 { - continue - } - key, value, ok := memROCut(line, '=') - if !ok { - return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy()) - } - valueBytes := scanner.Bytes()[key.Len()+1:] - - if key.EqualString("public_key") { - if deviceConfig { - deviceConfig = false - } - // Load/create the peer we are now configuring. - var err error - peer, err = cfg.handlePublicKeyLine(valueBytes) - if err != nil { - return nil, err - } - continue - } - - var err error - if deviceConfig { - err = cfg.handleDeviceLine(key, value, valueBytes) - } else { - err = cfg.handlePeerLine(peer, key, value, valueBytes) - } - if err != nil { - return nil, err - } - } - - if err := scanner.Err(); err != nil { - return nil, err - } - - return cfg, nil -} - -func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error { - switch { - case k.EqualString("private_key"): - // wireguard-go guarantees not to send zero value; private keys are already clamped. - var err error - cfg.PrivateKey, err = key.ParseNodePrivateUntyped(value) - if err != nil { - return err - } - case k.EqualString("listen_port") || k.EqualString("fwmark"): - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) - } - return nil -} - -func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) { - p := Peer{} - var err error - p.PublicKey, err = key.ParseNodePublicUntyped(mem.B(valueBytes)) - if err != nil { - return nil, err - } - cfg.Peers = append(cfg.Peers, p) - return &cfg.Peers[len(cfg.Peers)-1], nil -} - -func (cfg *Config) handlePeerLine(peer *Peer, k, value mem.RO, valueBytes []byte) error { - switch { - case k.EqualString("endpoint"): - nk, err := key.ParseNodePublicUntyped(value) - if err != nil { - return fmt.Errorf("invalid endpoint %q for peer %q, expected a hex public key", value.StringCopy(), peer.PublicKey.ShortString()) - } - // nk ought to equal peer.PublicKey. - // Under some rare circumstances, it might not. See corp issue #3016. - // Even if that happens, don't stop early, so that we can recover from it. - // Instead, note the value of nk so we can fix as needed. - peer.WGEndpoint = nk - case k.EqualString("persistent_keepalive_interval"): - n, err := mem.ParseUint(value, 10, 16) - if err != nil { - return err - } - peer.PersistentKeepalive = uint16(n) - case k.EqualString("allowed_ip"): - ipp := netip.Prefix{} - err := ipp.UnmarshalText(valueBytes) - if err != nil { - return err - } - peer.AllowedIPs = append(peer.AllowedIPs, ipp) - case k.EqualString("protocol_version"): - if !value.EqualString("1") { - return fmt.Errorf("invalid protocol version: %q", value.StringCopy()) - } - case k.EqualString("replace_allowed_ips") || - k.EqualString("preshared_key") || - k.EqualString("last_handshake_time_sec") || - k.EqualString("last_handshake_time_nsec") || - k.EqualString("tx_bytes") || - k.EqualString("rx_bytes"): - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) - } - return nil -} diff --git a/wgengine/wgcfg/parser_test.go b/wgengine/wgcfg/parser_test.go deleted file mode 100644 index 8c38ec0251b21..0000000000000 --- a/wgengine/wgcfg/parser_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "bytes" - "io" - "net/netip" - "reflect" - "runtime" - "testing" - - "tailscale.com/types/key" -) - -func noError(t *testing.T, err error) bool { - if err == nil { - return true - } - _, fn, line, _ := runtime.Caller(1) - t.Errorf("Error at %s:%d: %#v", fn, line, err) - return false -} - -func equal(t *testing.T, expected, actual any) bool { - if reflect.DeepEqual(expected, actual) { - return true - } - _, fn, line, _ := runtime.Caller(1) - t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) - return false -} - -func TestParseEndpoint(t *testing.T) { - _, _, err := parseEndpoint("[192.168.42.0:]:51880") - if err == nil { - t.Error("Error was expected") - } - host, port, err := parseEndpoint("192.168.42.0:51880") - if noError(t, err) { - equal(t, "192.168.42.0", host) - equal(t, uint16(51880), port) - } - host, port, err = parseEndpoint("test.wireguard.com:18981") - if noError(t, err) { - equal(t, "test.wireguard.com", host) - equal(t, uint16(18981), port) - } - host, port, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468") - if noError(t, err) { - equal(t, "2607:5300:60:6b0::c05f:543", host) - equal(t, uint16(2468), port) - } - _, _, err = parseEndpoint("[::::::invalid:18981") - if err == nil { - t.Error("Error was expected") - } -} - -func BenchmarkFromUAPI(b *testing.B) { - newK := func() (key.NodePublic, key.NodePrivate) { - b.Helper() - k := key.NewNode() - return k.Public(), k - } - k1, pk1 := newK() - ip1 := netip.MustParsePrefix("10.0.0.1/32") - - peer := Peer{ - PublicKey: k1, - AllowedIPs: []netip.Prefix{ip1}, - } - cfg1 := &Config{ - PrivateKey: pk1, - Peers: []Peer{peer, peer, peer, peer}, - } - - buf := new(bytes.Buffer) - w := bufio.NewWriter(buf) - if err := cfg1.ToUAPI(b.Logf, w, &Config{}); err != nil { - b.Fatal(err) - } - w.Flush() - r := bytes.NewReader(buf.Bytes()) - b.ReportAllocs() - for range b.N { - r.Seek(0, io.SeekStart) - _, err := FromUAPI(r) - if err != nil { - b.Errorf("failed from UAPI: %v", err) - } - } -} diff --git a/wgengine/wgcfg/wgcfg_clone.go b/wgengine/wgcfg/wgcfg_clone.go index 5c771a2288fce..a8a2122678b02 100644 --- a/wgengine/wgcfg/wgcfg_clone.go +++ b/wgengine/wgcfg/wgcfg_clone.go @@ -10,7 +10,6 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logid" - "tailscale.com/types/ptr" ) // Clone makes a deep copy of Config. @@ -56,10 +55,10 @@ func (src *Peer) Clone() *Peer { *dst = *src dst.AllowedIPs = append(src.AllowedIPs[:0:0], src.AllowedIPs...) if dst.V4MasqAddr != nil { - dst.V4MasqAddr = ptr.To(*src.V4MasqAddr) + dst.V4MasqAddr = new(*src.V4MasqAddr) } if dst.V6MasqAddr != nil { - dst.V6MasqAddr = ptr.To(*src.V6MasqAddr) + dst.V6MasqAddr = new(*src.V6MasqAddr) } return dst } @@ -73,5 +72,4 @@ var _PeerCloneNeedsRegeneration = Peer(struct { V6MasqAddr *netip.Addr IsJailed bool PersistentKeepalive uint16 - WGEndpoint key.NodePublic }{}) diff --git a/wgengine/wgcfg/writer.go b/wgengine/wgcfg/writer.go deleted file mode 100644 index f4981e3e9185b..0000000000000 --- a/wgengine/wgcfg/writer.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "fmt" - "io" - "net/netip" - "strconv" - - "tailscale.com/types/key" - "tailscale.com/types/logger" -) - -// ToUAPI writes cfg in UAPI format to w. -// Prev is the previous device Config. -// -// Prev is required so that we can remove now-defunct peers without having to -// remove and re-add all peers, and so that we can avoid writing information -// about peers that have not changed since the previous time we wrote our -// Config. -func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error { - var stickyErr error - set := func(key, value string) { - if stickyErr != nil { - return - } - _, err := fmt.Fprintf(w, "%s=%s\n", key, value) - if err != nil { - stickyErr = err - } - } - setUint16 := func(key string, value uint16) { - set(key, strconv.FormatUint(uint64(value), 10)) - } - setPeer := func(peer Peer) { - set("public_key", peer.PublicKey.UntypedHexString()) - } - - // Device config. - if !prev.PrivateKey.Equal(cfg.PrivateKey) { - set("private_key", cfg.PrivateKey.UntypedHexString()) - } - - old := make(map[key.NodePublic]Peer) - for _, p := range prev.Peers { - old[p.PublicKey] = p - } - - // Add/configure all new peers. - for _, p := range cfg.Peers { - oldPeer, wasPresent := old[p.PublicKey] - - // We only want to write the peer header/version if we're about - // to change something about that peer, or if it's a new peer. - // Figure out up-front whether we'll need to do anything for - // this peer, and skip doing anything if not. - // - // If the peer was not present in the previous config, this - // implies that this is a new peer; set all of these to 'true' - // to ensure that we're writing the full peer configuration. - willSetEndpoint := oldPeer.WGEndpoint != p.PublicKey || !wasPresent - willChangeIPs := !cidrsEqual(oldPeer.AllowedIPs, p.AllowedIPs) || !wasPresent - willChangeKeepalive := oldPeer.PersistentKeepalive != p.PersistentKeepalive // if not wasPresent, no need to redundantly set zero (default) - - if !willSetEndpoint && !willChangeIPs && !willChangeKeepalive { - // It's safe to skip doing anything here; wireguard-go - // will not remove a peer if it's unspecified unless we - // tell it to (which we do below if necessary). - continue - } - - setPeer(p) - set("protocol_version", "1") - - // Avoid setting endpoints if the correct one is already known - // to WireGuard, because doing so generates a bit more work in - // calling magicsock's ParseEndpoint for effectively a no-op. - if willSetEndpoint { - if wasPresent { - // We had an endpoint, and it was wrong. - // By construction, this should not happen. - // If it does, keep going so that we can recover from it, - // but log so that we know about it, - // because it is an indicator of other failed invariants. - // See corp issue 3016. - logf("[unexpected] endpoint changed from %s to %s", oldPeer.WGEndpoint, p.PublicKey) - } - set("endpoint", p.PublicKey.UntypedHexString()) - } - - // TODO: replace_allowed_ips is expensive. - // If p.AllowedIPs is a strict superset of oldPeer.AllowedIPs, - // then skip replace_allowed_ips and instead add only - // the new ipps with allowed_ip. - if willChangeIPs { - set("replace_allowed_ips", "true") - for _, ipp := range p.AllowedIPs { - set("allowed_ip", ipp.String()) - } - } - - // Set PersistentKeepalive after the peer is otherwise configured, - // because it can trigger handshake packets. - if willChangeKeepalive { - setUint16("persistent_keepalive_interval", p.PersistentKeepalive) - } - } - - // Remove peers that were present but should no longer be. - for _, p := range cfg.Peers { - delete(old, p.PublicKey) - } - for _, p := range old { - setPeer(p) - set("remove", "true") - } - - if stickyErr != nil { - stickyErr = fmt.Errorf("ToUAPI: %w", stickyErr) - } - return stickyErr -} - -func cidrsEqual(x, y []netip.Prefix) bool { - // TODO: re-implement using netaddr.IPSet.Equal. - if len(x) != len(y) { - return false - } - // First see if they're equal in order, without allocating. - exact := true - for i := range x { - if x[i] != y[i] { - exact = false - break - } - } - if exact { - return true - } - - // Otherwise, see if they're the same, but out of order. - m := make(map[netip.Prefix]bool) - for _, v := range x { - m[v] = true - } - for _, v := range y { - if !m[v] { - return false - } - } - return true -} diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go index 9dd782e4ab44f..5ca4b75cfa110 100644 --- a/wgengine/wgengine.go +++ b/wgengine/wgengine.go @@ -137,4 +137,8 @@ type Engine interface { // packets traversing the data path. The hook can be uninstalled by // calling this function with a nil value. InstallCaptureHook(packet.CaptureCallback) + + // SetPeerByIPPacketFunc installs a callback used by wireguard-go to + // look up which peer should handle an outbound packet by destination IP. + SetPeerByIPPacketFunc(func(netip.Addr) (_ key.NodePublic, ok bool)) } diff --git a/wif/wif.go b/wif/wif.go index bb2e760f2c7b7..af36ea980ea5c 100644 --- a/wif/wif.go +++ b/wif/wif.go @@ -40,7 +40,8 @@ const ( // 1. GitHub Actions (strongest env signals; may run atop any cloud) // 2. AWS via IMDSv2 token endpoint (does not require env vars) // 3. GCP via metadata header semantics -// 4. Azure via metadata endpoint +// 4. AWS ECS via ECS token endpoint and env vars provided by ECS +// 5. Azure via metadata endpoint func ObtainProviderToken(ctx context.Context, audience string) (string, error) { env := detectEnvironment(ctx) @@ -69,6 +70,9 @@ func detectEnvironment(ctx context.Context) Environment { if detectGCPMetadata(ctx, client) { return EnvGCP } + if os.Getenv("ECS_CONTAINER_METADATA_URI_V4") != "" { + return EnvAWS + } return EnvNone } @@ -163,7 +167,7 @@ func acquireGitHubActionsIDToken(ctx context.Context, audience string) (string, } func acquireAWSWebIdentityToken(ctx context.Context, audience string) (string, error) { - // LoadDefaultConfig wires up the default credential chain (incl. IMDS). + // LoadDefaultConfig wires up the default credential chain (incl. IMDS and ECS metadata). cfg, err := config.LoadDefaultConfig(ctx) if err != nil { return "", fmt.Errorf("load aws config: %w", err) @@ -174,12 +178,15 @@ func acquireAWSWebIdentityToken(ctx context.Context, audience string) (string, e return "", fmt.Errorf("AWS credentials unavailable (instance profile/IMDS?): %w", err) } - imdsClient := imds.NewFromConfig(cfg) - region, err := imdsClient.GetRegion(ctx, &imds.GetRegionInput{}) - if err != nil { - return "", fmt.Errorf("couldn't get AWS region: %w", err) + // ECS does not have IMDS; region must come from AWS_DEFAULT_REGION or AWS_REGION, + if cfg.Region == "" { + imdsClient := imds.NewFromConfig(cfg) + region, err := imdsClient.GetRegion(ctx, &imds.GetRegionInput{}) + if err != nil { + return "", fmt.Errorf("couldn't get AWS region: %w", err) + } + cfg.Region = region.Region } - cfg.Region = region.Region stsClient := sts.NewFromConfig(cfg) in := &sts.GetWebIdentityTokenInput{ @@ -190,8 +197,7 @@ func acquireAWSWebIdentityToken(ctx context.Context, audience string) (string, e out, err := stsClient.GetWebIdentityToken(ctx, in) if err != nil { - var apiErr smithy.APIError - if errors.As(err, &apiErr) { + if apiErr, ok := errors.AsType[smithy.APIError](err); ok { return "", fmt.Errorf("aws sts:GetWebIdentityToken failed (%s): %w", apiErr.ErrorCode(), err) } return "", fmt.Errorf("aws sts:GetWebIdentityToken failed: %w", err) diff --git a/words/scales.txt b/words/scales.txt index ce749b9dcc368..831011deedae0 100644 --- a/words/scales.txt +++ b/words/scales.txt @@ -6,6 +6,7 @@ catfish bass salmon tuna +fish hammerhead eel carp