From 43e38f48706cfec6d9e7c40cae62af81c72c43d8 Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Wed, 20 Aug 2025 14:04:53 +1200 Subject: [PATCH 1/2] refactor: use Go 1.22's ServeMux path patterns instead gorilla/mux --- HACKING.md | 1 - docs/explanation/api-and-clients.md | 2 +- docs/specs/openapi.yaml | 12 ++++----- go.mod | 1 - go.sum | 2 -- internals/daemon/api.go | 6 +---- internals/daemon/api_changes.go | 12 ++++----- internals/daemon/api_changes_test.go | 8 +++--- internals/daemon/api_notices.go | 2 +- internals/daemon/api_notices_test.go | 12 ++++----- internals/daemon/api_tasks.go | 5 ++-- internals/daemon/api_test.go | 10 -------- internals/daemon/daemon.go | 20 ++++----------- internals/daemon/daemon_test.go | 38 ++++++++-------------------- internals/daemon/export_test.go | 9 ------- 15 files changed, 42 insertions(+), 98 deletions(-) diff --git a/HACKING.md b/HACKING.md index 3a6f7baa2..036d9f549 100644 --- a/HACKING.md +++ b/HACKING.md @@ -133,7 +133,6 @@ import ( "net" "os" - "github.com/gorilla/mux" . "gopkg.in/check.v1" "github.com/canonical/pebble/internals/systemd" diff --git a/docs/explanation/api-and-clients.md b/docs/explanation/api-and-clients.md index c15bdf3e3..5bf569b4b 100644 --- a/docs/explanation/api-and-clients.md +++ b/docs/explanation/api-and-clients.md @@ -30,7 +30,7 @@ API endpoints fall into one of four access levels, from least restricted to most * **Admin-access** - Only allowed from admin users. For example, adding a layer or starting a service. * `GET /v1/files`, which pulls a file from a remote system - * `GET /v1/tasks/{task-id}/websocket/{websocket-id}` + * `GET /v1/tasks/{taskID}/websocket/{websocketID}` * All `POST` endpoints except `POST /v1/notices` (which is read-access) Pebble authenticates clients that connect to the socket API using peer credentials ([`SO_PEERCRED`](https://man7.org/linux/man-pages/man7/socket.7.html)) to determine the user ID (UID) of the connecting process. If this UID is 0 (root) or the UID of the Pebble daemon, the user's access level is `admin`, otherwise the access level is `read`. diff --git a/docs/specs/openapi.yaml b/docs/specs/openapi.yaml index 5bf011722..dab85d687 100644 --- a/docs/specs/openapi.yaml +++ b/docs/specs/openapi.yaml @@ -243,7 +243,7 @@ paths: "ready-time": "2024-12-27T12:31:27.686371869+08:00" } } - /v1/tasks/{task-id}/websocket/{websocket-id}: + /v1/tasks/{taskID}/websocket/{websocketID}: get: summary: Connect to a task's websocket tags: @@ -251,13 +251,13 @@ paths: description: Establish a websocket connection to a specific task. parameters: - in: path - name: task-id + name: taskID schema: type: string required: true description: The ID of the task. - in: path - name: websocket-id + name: websocketID schema: type: string required: true @@ -401,9 +401,9 @@ paths: description: | Start a command with the given options and return a value representing the process. - This API returns a `task-id` (see the response schema and the example below), - then you need to call `/v1/tasks/{task-id}/websocket/control` and `/v1/tasks/{task-id}/websocket/stdio` - (also `/v1/tasks/{task-id}/websocket/stderr` if `split-stderr` is true) with the returned `task-id`. + This API returns a task ID (see the response schema and the example below), + then you need to call `/v1/tasks/{taskID}/websocket/control` and `/v1/tasks/{taskID}/websocket/stdio` + (also `/v1/tasks/{taskID}/websocket/stderr` if `split-stderr` is true) with the returned task ID. requestBody: required: true content: diff --git a/go.mod b/go.mod index 4c8a68532..e1b5530dd 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5 github.com/canonical/go-flags v0.0.0-20230403090104-105d09a091b8 github.com/canonical/x-go v0.0.0-20230522092633-7947a7587f5b - github.com/gorilla/mux v1.8.1 github.com/gorilla/websocket v1.5.1 github.com/pkg/term v1.1.0 golang.org/x/sys v0.33.0 diff --git a/go.sum b/go.sum index 427f4f6fa..204146f83 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,6 @@ github.com/canonical/x-go v0.0.0-20230522092633-7947a7587f5b/go.mod h1:upTK9n6rl github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= -github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= diff --git a/internals/daemon/api.go b/internals/daemon/api.go index dec790467..82ef5f29b 100644 --- a/internals/daemon/api.go +++ b/internals/daemon/api.go @@ -18,8 +18,6 @@ import ( "net/http" "strconv" - "github.com/gorilla/mux" - "github.com/canonical/pebble/internals/overlord" "github.com/canonical/pebble/internals/overlord/restart" "github.com/canonical/pebble/internals/overlord/state" @@ -82,7 +80,7 @@ var API = []*Command{{ WriteAccess: AdminAccess{}, POST: v1PostExec, }, { - Path: "/v1/tasks/{task-id}/websocket/{websocket-id}", + Path: "/v1/tasks/{taskID}/websocket/{websocketID}", ReadAccess: AdminAccess{}, // used by exec, so require admin GET: v1GetTaskWebsocket, }, { @@ -127,8 +125,6 @@ var ( overlordServiceManager = (*overlord.Overlord).ServiceManager overlordPlanManager = (*overlord.Overlord).PlanManager overlordCheckManager = (*overlord.Overlord).CheckManager - - muxVars = mux.Vars ) func v1SystemInfo(c *Command, r *http.Request, _ *UserState) Response { diff --git a/internals/daemon/api_changes.go b/internals/daemon/api_changes.go index 962833b19..b3c6991ad 100644 --- a/internals/daemon/api_changes.go +++ b/internals/daemon/api_changes.go @@ -171,7 +171,7 @@ func v1GetChanges(c *Command, r *http.Request, _ *UserState) Response { } func v1GetChange(c *Command, r *http.Request, _ *UserState) Response { - changeID := muxVars(r)["id"] + changeID := r.PathValue("id") st := c.d.overlord.State() st.Lock() defer st.Unlock() @@ -184,7 +184,7 @@ func v1GetChange(c *Command, r *http.Request, _ *UserState) Response { } func v1GetChangeWait(c *Command, r *http.Request, _ *UserState) Response { - changeID := muxVars(r)["id"] + changeID := r.PathValue("id") st := c.d.overlord.State() st.Lock() change := st.Change(changeID) @@ -224,13 +224,13 @@ func v1GetChangeWait(c *Command, r *http.Request, _ *UserState) Response { } func v1PostChange(c *Command, r *http.Request, _ *UserState) Response { - chID := muxVars(r)["id"] + changeID := r.PathValue("id") state := c.d.overlord.State() state.Lock() defer state.Unlock() - chg := state.Change(chID) + chg := state.Change(changeID) if chg == nil { - return NotFound("cannot find change with id %q", chID) + return NotFound("cannot find change with id %q", changeID) } var reqData struct { @@ -247,7 +247,7 @@ func v1PostChange(c *Command, r *http.Request, _ *UserState) Response { } if chg.Status().Ready() { - return BadRequest("cannot abort change %s with nothing pending", chID) + return BadRequest("cannot abort change %s with nothing pending", changeID) } // flag the change diff --git a/internals/daemon/api_changes_test.go b/internals/daemon/api_changes_test.go index 1eb3ba2e4..36cdf7eb4 100644 --- a/internals/daemon/api_changes_test.go +++ b/internals/daemon/api_changes_test.go @@ -205,13 +205,13 @@ func (s *apiSuite) TestStateChange(c *check.C) { task := chg.Tasks()[0] task.Set("api-data", map[string]string{"foo": "bar"}) st.Unlock() - s.vars = map[string]string{"id": ids[0]} stateChangeCmd := apiCmd("/v1/changes/{id}") // Execute req, err := http.NewRequest("GET", "/v1/change/"+ids[0], nil) c.Assert(err, check.IsNil) + req.SetPathValue("id", ids[0]) rsp := v1GetChange(stateChangeCmd, req, nil).(*resp) rec := httptest.NewRecorder() rsp.ServeHTTP(rec, req) @@ -276,7 +276,6 @@ func (s *apiSuite) TestStateChangeAbort(c *check.C) { st.Lock() ids := setupChanges(st) st.Unlock() - s.vars = map[string]string{"id": ids[0]} buf := bytes.NewBufferString(`{"action": "abort"}`) @@ -285,6 +284,7 @@ func (s *apiSuite) TestStateChangeAbort(c *check.C) { // Execute req, err := http.NewRequest("POST", "/v1/changes/"+ids[0], buf) c.Assert(err, check.IsNil) + req.SetPathValue("id", ids[0]) rsp := v1PostChange(stateChangeCmd, req, nil).(*resp) rec := httptest.NewRecorder() rsp.ServeHTTP(rec, req) @@ -344,7 +344,6 @@ func (s *apiSuite) TestStateChangeAbortIsReady(c *check.C) { ids := setupChanges(st) st.Change(ids[0]).SetStatus(state.DoneStatus) st.Unlock() - s.vars = map[string]string{"id": ids[0]} buf := bytes.NewBufferString(`{"action": "abort"}`) @@ -353,6 +352,7 @@ func (s *apiSuite) TestStateChangeAbortIsReady(c *check.C) { // Execute req, err := http.NewRequest("POST", "/v1/changes/"+ids[0], buf) c.Assert(err, check.IsNil) + req.SetPathValue("id", ids[0]) rsp := v1PostChange(stateChangeCmd, req, nil).(*resp) rec := httptest.NewRecorder() rsp.ServeHTTP(rec, req) @@ -459,9 +459,9 @@ func (s *apiSuite) testWaitChange(ctx context.Context, c *check.C, query string, } // Execute - s.vars = map[string]string{"id": change.ID()} req, err := http.NewRequestWithContext(ctx, "GET", "/v1/changes/"+change.ID()+"/wait"+query, nil) c.Assert(err, check.IsNil) + req.SetPathValue("id", change.ID()) rsp := v1GetChangeWait(apiCmd("/v1/changes/{id}/wait"), req, nil).(*resp) rec := httptest.NewRecorder() rsp.ServeHTTP(rec, req) diff --git a/internals/daemon/api_notices.go b/internals/daemon/api_notices.go index 8cd274508..a5e6f5a69 100644 --- a/internals/daemon/api_notices.go +++ b/internals/daemon/api_notices.go @@ -243,7 +243,7 @@ func v1GetNotice(c *Command, r *http.Request, user *UserState) Response { if user == nil || user.UID == nil { return Forbidden("cannot determine UID of request, so cannot retrieve notice") } - noticeID := muxVars(r)["id"] + noticeID := r.PathValue("id") st := c.d.overlord.State() st.Lock() defer st.Unlock() diff --git a/internals/daemon/api_notices_test.go b/internals/daemon/api_notices_test.go index 607e9f869..02e69a1a1 100644 --- a/internals/daemon/api_notices_test.go +++ b/internals/daemon/api_notices_test.go @@ -764,7 +764,7 @@ func (s *apiSuite) TestNotice(c *C) { req, err := http.NewRequest("GET", "/v1/notices/"+noticeIDPublic, nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") - s.vars = map[string]string{"id": noticeIDPublic} + req.SetPathValue("id", noticeIDPublic) rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) @@ -780,7 +780,7 @@ func (s *apiSuite) TestNotice(c *C) { req, err = http.NewRequest("GET", "/v1/notices/"+noticeIDPrivate, nil) c.Assert(err, IsNil) noticesCmd = apiCmd("/v1/notices/{id}") - s.vars = map[string]string{"id": noticeIDPrivate} + req.SetPathValue("id", noticeIDPrivate) rsp, ok = noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) @@ -800,7 +800,7 @@ func (s *apiSuite) TestNoticeNotFound(c *C) { req, err := http.NewRequest("GET", "/v1/notices/1234", nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") - s.vars = map[string]string{"id": "1234"} + req.SetPathValue("id", "1234") rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) @@ -816,7 +816,7 @@ func (s *apiSuite) TestNoticeUnknownRequestUID(c *C) { req, err := http.NewRequest("GET", "/v1/notices/1234", nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") - s.vars = map[string]string{"id": "1234"} + req.SetPathValue("id", "1234") rsp, ok := noticesCmd.GET(noticesCmd, req, &UserState{Access: state.ReadAccess}).(*resp) c.Assert(ok, Equals, true) @@ -839,7 +839,7 @@ func (s *apiSuite) TestNoticeAdminAllowed(c *C) { req, err := http.NewRequest("GET", "/v1/notices/"+noticeID, nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") - s.vars = map[string]string{"id": noticeID} + req.SetPathValue("id", noticeID) rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.AdminAccess, 0)).(*resp) c.Assert(ok, Equals, true) @@ -866,7 +866,7 @@ func (s *apiSuite) TestNoticeNonAdminNotAllowed(c *C) { req, err := http.NewRequest("GET", "/v1/notices/"+noticeID, nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") - s.vars = map[string]string{"id": noticeID} + req.SetPathValue("id", noticeID) rsp, ok := noticesCmd.GET(noticesCmd, req, userState(state.ReadAccess, 1001)).(*resp) c.Assert(ok, Equals, true) diff --git a/internals/daemon/api_tasks.go b/internals/daemon/api_tasks.go index 59c16c586..7e7672919 100644 --- a/internals/daemon/api_tasks.go +++ b/internals/daemon/api_tasks.go @@ -24,9 +24,8 @@ import ( ) func v1GetTaskWebsocket(c *Command, req *http.Request, _ *UserState) Response { - vars := muxVars(req) - taskID := vars["task-id"] - websocketID := vars["websocket-id"] + taskID := req.PathValue("taskID") + websocketID := req.PathValue("websocketID") st := c.d.overlord.State() st.Lock() diff --git a/internals/daemon/api_test.go b/internals/daemon/api_test.go index bcbdddec6..b5012fc82 100644 --- a/internals/daemon/api_test.go +++ b/internals/daemon/api_test.go @@ -16,7 +16,6 @@ package daemon import ( "encoding/json" - "net/http" "net/http/httptest" "os" @@ -33,9 +32,6 @@ type apiSuite struct { pebbleDir string - vars map[string]string - - restoreMuxVars func() overlordStarted bool } @@ -45,7 +41,6 @@ func (s *apiSuite) SetUpTest(c *check.C) { c.Fatalf("cannot start reaper: %v", err) } - s.restoreMuxVars = FakeMuxVars(s.muxVars) s.pebbleDir = c.MkDir() } @@ -56,7 +51,6 @@ func (s *apiSuite) TearDownTest(c *check.C) { } s.d = nil s.pebbleDir = "" - s.restoreMuxVars() err := reaper.Stop() if err != nil { @@ -64,10 +58,6 @@ func (s *apiSuite) TearDownTest(c *check.C) { } } -func (s *apiSuite) muxVars(*http.Request) map[string]string { - return s.vars -} - func (s *apiSuite) daemon(c *check.C) *Daemon { if s.d != nil { panic("called daemon() twice") diff --git a/internals/daemon/daemon.go b/internals/daemon/daemon.go index 9167f28b3..6c07902ef 100644 --- a/internals/daemon/daemon.go +++ b/internals/daemon/daemon.go @@ -34,7 +34,6 @@ import ( "syscall" "time" - "github.com/gorilla/mux" "gopkg.in/tomb.v2" "github.com/canonical/pebble/internals/logger" @@ -177,7 +176,7 @@ type Daemon struct { connTracker *connTracker serve *http.Server tomb tomb.Tomb - router *mux.Router + router *http.ServeMux standbyOpinions *standby.StandbyOpinions // set to what kind of restart was requested (if any) @@ -203,9 +202,8 @@ type ResponseFunc func(*Command, *http.Request, *UserState) Response // A Command routes a request to an individual per-verb ResponseFUnc type Command struct { - Path string - PathPrefix string - // + Path string + GET ResponseFunc PUT ResponseFunc POST ResponseFunc @@ -488,20 +486,12 @@ func (d *Daemon) SetDegradedMode(err error) { } func (d *Daemon) addRoutes() { - d.router = mux.NewRouter() + d.router = http.NewServeMux() for _, c := range API { c.d = d - if c.PathPrefix == "" { - d.router.Handle(c.Path, c).Name(c.Path) - } else { - d.router.PathPrefix(c.PathPrefix).Handler(c).Name(c.PathPrefix) - } + d.router.Handle(c.Path, c) } - - // also maybe add a /favicon.ico handler... - - d.router.NotFoundHandler = NotFound("invalid API endpoint requested") } type connTracker struct { diff --git a/internals/daemon/daemon_test.go b/internals/daemon/daemon_test.go index 12f1297f8..f691cfed3 100644 --- a/internals/daemon/daemon_test.go +++ b/internals/daemon/daemon_test.go @@ -39,7 +39,6 @@ import ( "testing" "time" - "github.com/gorilla/mux" . "gopkg.in/check.v1" "github.com/canonical/pebble/cmd" @@ -231,12 +230,16 @@ func (s *daemonSuite) TestAddCommand(c *C) { }() d := s.newDaemon(c) - d.Init() - c.Assert(d.Start(), IsNil) - defer d.Stop(nil) - result := d.router.Get(endpoint).GetHandler() - c.Assert(result, Equals, &command) + ctx := context.WithValue(context.Background(), TransportTypeKey{}, TransportTypeUnixSocket) + req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil) + c.Assert(err, IsNil) + + rec := httptest.NewRecorder() + d.router.ServeHTTP(rec, req) + c.Assert(handler.cmd, Equals, &command) + c.Assert(handler.lastMethod, Equals, "GET") + c.Assert(rec.Code, Equals, http.StatusOK) } func (s *daemonSuite) TestExplicitPaths(c *C) { @@ -623,27 +626,6 @@ func (s *daemonSuite) TestDefaultUcredUsers(c *C) { c.Check(*userSeen.UID, Equals, uint32(os.Getuid()+1)) } -func (s *daemonSuite) TestAddRoutes(c *C) { - d := s.newDaemon(c) - - expected := make([]string, len(API)) - for i, v := range API { - if v.PathPrefix != "" { - expected[i] = v.PathPrefix - continue - } - expected[i] = v.Path - } - - got := make([]string, 0, len(API)) - c.Assert(d.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { - got = append(got, route.GetName()) - return nil - }), IsNil) - - c.Check(got, DeepEquals, expected) // this'll stop being true if routes are added that aren't commands (e.g. for the favicon) -} - type witnessAcceptListener struct { net.Listener @@ -1566,7 +1548,7 @@ func (s *daemonSuite) TestWritesRequireAdminAccess(c *C) { } // Task websockets (GET) is used for exec, so requires admin access too. - cmd = apiCmd("/v1/tasks/{task-id}/websocket/{websocket-id}") + cmd = apiCmd("/v1/tasks/{taskID}/websocket/{websocketID}") switch cmd.ReadAccess.(type) { case OpenAccess, UserAccess: c.Errorf("%s ReadAccess should be AdminAccess, not %T", cmd.Path, cmd.WriteAccess) diff --git a/internals/daemon/export_test.go b/internals/daemon/export_test.go index ed41b4ace..d06d2229a 100644 --- a/internals/daemon/export_test.go +++ b/internals/daemon/export_test.go @@ -15,7 +15,6 @@ package daemon import ( - "net/http" "time" "github.com/canonical/pebble/internals/overlord" @@ -23,14 +22,6 @@ import ( "github.com/canonical/pebble/internals/overlord/state" ) -func FakeMuxVars(f func(*http.Request) map[string]string) (restore func()) { - old := muxVars - muxVars = f - return func() { - muxVars = old - } -} - func FakeStateEnsureBefore(f func(st *state.State, d time.Duration)) (restore func()) { old := stateEnsureBefore stateEnsureBefore = f From 03f25f5ea3ca5b65d64c921e1d3daaf19580fadc Mon Sep 17 00:00:00 2001 From: Ben Hoyt Date: Mon, 16 Feb 2026 16:42:30 +1300 Subject: [PATCH 2/2] reinstate not-found handler for unknown routes --- internals/daemon/daemon.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internals/daemon/daemon.go b/internals/daemon/daemon.go index 9c66f330e..12a85c3f9 100644 --- a/internals/daemon/daemon.go +++ b/internals/daemon/daemon.go @@ -521,6 +521,8 @@ func (d *Daemon) addRoutes() { c.d = d d.router.Handle(c.Path, c) } + + d.router.Handle("/", NotFound("invalid API endpoint requested")) } type connTracker struct {