diff --git a/HACKING.md b/HACKING.md index 7f3c89be2..8633d1ca7 100644 --- a/HACKING.md +++ b/HACKING.md @@ -133,7 +133,6 @@ import ( "net" "os" - "github.com/gorilla/mux" . "gopkg.in/check.v1" "github.com/canonical/pebble/internals/systemd" diff --git a/docs/explanation/api-and-clients.md b/docs/explanation/api-and-clients.md index 5047fe072..d39de593d 100644 --- a/docs/explanation/api-and-clients.md +++ b/docs/explanation/api-and-clients.md @@ -30,7 +30,7 @@ API endpoints fall into one of four access levels, from least restricted to most * **Admin-access** - Only allowed from admin users. For example, adding a layer or starting a service. * `GET /v1/files`, which pulls a file from a remote system - * `GET /v1/tasks/{task-id}/websocket/{websocket-id}` + * `GET /v1/tasks/{taskID}/websocket/{websocketID}` * All `POST` endpoints except `POST /v1/notices` (which is read-access) Pebble authenticates clients that connect to the socket API using peer credentials ([`SO_PEERCRED`](https://man7.org/linux/man-pages/man7/socket.7.html)) to determine the user ID (UID) of the connecting process. If this UID is 0 (root) or the UID of the Pebble daemon, the user's access level is `admin`, otherwise the access level is `read`. diff --git a/docs/specs/openapi.yaml b/docs/specs/openapi.yaml index 6fe3860ef..08e15a411 100644 --- a/docs/specs/openapi.yaml +++ b/docs/specs/openapi.yaml @@ -243,7 +243,7 @@ paths: "ready-time": "2024-12-27T12:31:27.686371869+08:00" } } - /v1/tasks/{task-id}/websocket/{websocket-id}: + /v1/tasks/{taskID}/websocket/{websocketID}: get: summary: Connect to a task's websocket tags: @@ -251,13 +251,13 @@ paths: description: Establish a websocket connection to a specific task. parameters: - in: path - name: task-id + name: taskID schema: type: string required: true description: The ID of the task. - in: path - name: websocket-id + name: websocketID schema: type: string required: true @@ -402,9 +402,9 @@ paths: description: | Start a command with the given options and return a value representing the process. - This API returns a `task-id` (see the response schema and the example below), - then you need to call `/v1/tasks/{task-id}/websocket/control` and `/v1/tasks/{task-id}/websocket/stdio` - (also `/v1/tasks/{task-id}/websocket/stderr` if `split-stderr` is true) with the returned `task-id`. + This API returns a task ID (see the response schema and the example below), + then you need to call `/v1/tasks/{taskID}/websocket/control` and `/v1/tasks/{taskID}/websocket/stdio` + (also `/v1/tasks/{taskID}/websocket/stderr` if `split-stderr` is true) with the returned task ID. requestBody: required: true content: diff --git a/go.mod b/go.mod index df4b09bc3..c0b695690 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,6 @@ require ( github.com/GehirnInc/crypt v0.0.0-20230320061759-8cc1b52080c5 github.com/canonical/go-flags v0.0.0-20230403090104-105d09a091b8 github.com/canonical/x-go v0.0.0-20230522092633-7947a7587f5b - github.com/gorilla/mux v1.8.1 github.com/gorilla/websocket v1.5.1 github.com/pkg/term v1.1.0 golang.org/x/sys v0.33.0 diff --git a/go.sum b/go.sum index 427f4f6fa..204146f83 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,6 @@ github.com/canonical/x-go v0.0.0-20230522092633-7947a7587f5b/go.mod h1:upTK9n6rl github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= -github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= diff --git a/internals/daemon/api.go b/internals/daemon/api.go index 8f3376280..4635a1eaa 100644 --- a/internals/daemon/api.go +++ b/internals/daemon/api.go @@ -18,8 +18,6 @@ import ( "net/http" "strconv" - "github.com/gorilla/mux" - "github.com/canonical/pebble/internals/overlord" "github.com/canonical/pebble/internals/overlord/restart" "github.com/canonical/pebble/internals/overlord/state" @@ -82,7 +80,7 @@ var API = []*Command{{ WriteAccess: AdminAccess{}, POST: v1PostExec, }, { - Path: "/v1/tasks/{task-id}/websocket/{websocket-id}", + Path: "/v1/tasks/{taskID}/websocket/{websocketID}", ReadAccess: AdminAccess{}, // used by exec, so require admin GET: v1GetTaskWebsocket, }, { @@ -131,8 +129,6 @@ var ( overlordServiceManager = (*overlord.Overlord).ServiceManager overlordPlanManager = (*overlord.Overlord).PlanManager overlordCheckManager = (*overlord.Overlord).CheckManager - - muxVars = mux.Vars ) func v1SystemInfo(c *Command, r *http.Request, _ *UserState) Response { diff --git a/internals/daemon/api_changes.go b/internals/daemon/api_changes.go index 18bad8f47..553f8c314 100644 --- a/internals/daemon/api_changes.go +++ b/internals/daemon/api_changes.go @@ -166,7 +166,7 @@ func v1GetChanges(c *Command, r *http.Request, _ *UserState) Response { } func v1GetChange(c *Command, r *http.Request, _ *UserState) Response { - changeID := muxVars(r)["id"] + changeID := r.PathValue("id") st := c.d.overlord.State() st.Lock() defer st.Unlock() @@ -179,7 +179,7 @@ func v1GetChange(c *Command, r *http.Request, _ *UserState) Response { } func v1GetChangeWait(c *Command, r *http.Request, _ *UserState) Response { - changeID := muxVars(r)["id"] + changeID := r.PathValue("id") st := c.d.overlord.State() st.Lock() change := st.Change(changeID) @@ -219,13 +219,13 @@ func v1GetChangeWait(c *Command, r *http.Request, _ *UserState) Response { } func v1PostChange(c *Command, r *http.Request, _ *UserState) Response { - chID := muxVars(r)["id"] + changeID := r.PathValue("id") state := c.d.overlord.State() state.Lock() defer state.Unlock() - chg := state.Change(chID) + chg := state.Change(changeID) if chg == nil { - return NotFound("cannot find change with id %q", chID) + return NotFound("cannot find change with id %q", changeID) } var reqData struct { @@ -242,7 +242,7 @@ func v1PostChange(c *Command, r *http.Request, _ *UserState) Response { } if chg.Status().Ready() { - return BadRequest("cannot abort change %s with nothing pending", chID) + return BadRequest("cannot abort change %s with nothing pending", changeID) } // flag the change diff --git a/internals/daemon/api_changes_test.go b/internals/daemon/api_changes_test.go index ceb38358f..db42a99c9 100644 --- a/internals/daemon/api_changes_test.go +++ b/internals/daemon/api_changes_test.go @@ -205,13 +205,13 @@ func (s *apiSuite) TestStateChange(c *check.C) { task := chg.Tasks()[0] task.Set("api-data", map[string]string{"foo": "bar"}) st.Unlock() - s.vars = map[string]string{"id": ids[0]} stateChangeCmd := apiCmd("/v1/changes/{id}") // Execute req, err := http.NewRequest("GET", "/v1/change/"+ids[0], nil) c.Assert(err, check.IsNil) + req.SetPathValue("id", ids[0]) rsp := v1GetChange(stateChangeCmd, req, nil).(*resp) rec := httptest.NewRecorder() rsp.ServeHTTP(rec, req) @@ -276,7 +276,6 @@ func (s *apiSuite) TestStateChangeAbort(c *check.C) { st.Lock() ids := setupChanges(st) st.Unlock() - s.vars = map[string]string{"id": ids[0]} buf := bytes.NewBufferString(`{"action": "abort"}`) @@ -285,6 +284,7 @@ func (s *apiSuite) TestStateChangeAbort(c *check.C) { // Execute req, err := http.NewRequest("POST", "/v1/changes/"+ids[0], buf) c.Assert(err, check.IsNil) + req.SetPathValue("id", ids[0]) rsp := v1PostChange(stateChangeCmd, req, nil).(*resp) rec := httptest.NewRecorder() rsp.ServeHTTP(rec, req) @@ -344,7 +344,6 @@ func (s *apiSuite) TestStateChangeAbortIsReady(c *check.C) { ids := setupChanges(st) st.Change(ids[0]).SetStatus(state.DoneStatus) st.Unlock() - s.vars = map[string]string{"id": ids[0]} buf := bytes.NewBufferString(`{"action": "abort"}`) @@ -353,6 +352,7 @@ func (s *apiSuite) TestStateChangeAbortIsReady(c *check.C) { // Execute req, err := http.NewRequest("POST", "/v1/changes/"+ids[0], buf) c.Assert(err, check.IsNil) + req.SetPathValue("id", ids[0]) rsp := v1PostChange(stateChangeCmd, req, nil).(*resp) rec := httptest.NewRecorder() rsp.ServeHTTP(rec, req) @@ -459,9 +459,9 @@ func (s *apiSuite) testWaitChange(ctx context.Context, c *check.C, query string, } // Execute - s.vars = map[string]string{"id": change.ID()} req, err := http.NewRequestWithContext(ctx, "GET", "/v1/changes/"+change.ID()+"/wait"+query, nil) c.Assert(err, check.IsNil) + req.SetPathValue("id", change.ID()) rsp := v1GetChangeWait(apiCmd("/v1/changes/{id}/wait"), req, nil).(*resp) rec := httptest.NewRecorder() rsp.ServeHTTP(rec, req) diff --git a/internals/daemon/api_exec_test.go b/internals/daemon/api_exec_test.go index f63850909..eaa5159d3 100644 --- a/internals/daemon/api_exec_test.go +++ b/internals/daemon/api_exec_test.go @@ -471,15 +471,11 @@ func (s *execSuite) TestExecChangeReady(c *C) { taskID, ok := execResp.Result["task-id"].(string) c.Assert(ok, Equals, true) - vars := map[string]string{"task-id": taskID, "websocket-id": "control"} - restoreMuxVars := FakeMuxVars(func(*http.Request) map[string]string { - return vars - }) - defer restoreMuxVars() - - websocketCmd := apiCmd("/v1/tasks/{task-id}/websocket/{websocket-id}") + websocketCmd := apiCmd("/v1/tasks/{taskID}/websocket/{websocketID}") req, err := http.NewRequest("GET", fmt.Sprintf("/v1/tasks/%s/websocket/%s", taskID, "control"), nil) c.Assert(err, IsNil) + req.SetPathValue("taskID", taskID) + req.SetPathValue("websocketID", "control") rsp := v1GetTaskWebsocket(websocketCmd, req, nil).(websocketResponse) rec := httptest.NewRecorder() rsp.ServeHTTP(rec, req) diff --git a/internals/daemon/api_notices.go b/internals/daemon/api_notices.go index 2a4ccb496..666723e96 100644 --- a/internals/daemon/api_notices.go +++ b/internals/daemon/api_notices.go @@ -244,7 +244,7 @@ func v1GetNotice(c *Command, r *http.Request, user *UserState) Response { if user == nil || user.UID == nil { return Forbidden("cannot determine UID of request, so cannot retrieve notice") } - noticeID := muxVars(r)["id"] + noticeID := r.PathValue("id") st := c.d.overlord.State() st.Lock() defer st.Unlock() diff --git a/internals/daemon/api_notices_test.go b/internals/daemon/api_notices_test.go index 58162f5e6..5f4094a5a 100644 --- a/internals/daemon/api_notices_test.go +++ b/internals/daemon/api_notices_test.go @@ -765,7 +765,7 @@ func (s *apiSuite) TestNotice(c *C) { req, err := http.NewRequest("GET", "/v1/notices/"+noticeIDPublic, nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") - s.vars = map[string]string{"id": noticeIDPublic} + req.SetPathValue("id", noticeIDPublic) rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) @@ -781,7 +781,7 @@ func (s *apiSuite) TestNotice(c *C) { req, err = http.NewRequest("GET", "/v1/notices/"+noticeIDPrivate, nil) c.Assert(err, IsNil) noticesCmd = apiCmd("/v1/notices/{id}") - s.vars = map[string]string{"id": noticeIDPrivate} + req.SetPathValue("id", noticeIDPrivate) rsp, ok = noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) @@ -801,7 +801,7 @@ func (s *apiSuite) TestNoticeNotFound(c *C) { req, err := http.NewRequest("GET", "/v1/notices/1234", nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") - s.vars = map[string]string{"id": "1234"} + req.SetPathValue("id", "1234") rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1000)).(*resp) c.Assert(ok, Equals, true) @@ -817,7 +817,7 @@ func (s *apiSuite) TestNoticeUnknownRequestUID(c *C) { req, err := http.NewRequest("GET", "/v1/notices/1234", nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") - s.vars = map[string]string{"id": "1234"} + req.SetPathValue("id", "1234") rsp, ok := noticesCmd.GET(noticesCmd, req, &UserState{Access: identities.ReadAccess}).(*resp) c.Assert(ok, Equals, true) @@ -840,7 +840,7 @@ func (s *apiSuite) TestNoticeAdminAllowed(c *C) { req, err := http.NewRequest("GET", "/v1/notices/"+noticeID, nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") - s.vars = map[string]string{"id": noticeID} + req.SetPathValue("id", noticeID) rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.AdminAccess, 0)).(*resp) c.Assert(ok, Equals, true) @@ -867,7 +867,7 @@ func (s *apiSuite) TestNoticeNonAdminNotAllowed(c *C) { req, err := http.NewRequest("GET", "/v1/notices/"+noticeID, nil) c.Assert(err, IsNil) noticesCmd := apiCmd("/v1/notices/{id}") - s.vars = map[string]string{"id": noticeID} + req.SetPathValue("id", noticeID) rsp, ok := noticesCmd.GET(noticesCmd, req, userState(identities.ReadAccess, 1001)).(*resp) c.Assert(ok, Equals, true) diff --git a/internals/daemon/api_tasks.go b/internals/daemon/api_tasks.go index 0fe643366..d67523cb3 100644 --- a/internals/daemon/api_tasks.go +++ b/internals/daemon/api_tasks.go @@ -24,9 +24,8 @@ import ( ) func v1GetTaskWebsocket(c *Command, req *http.Request, _ *UserState) Response { - vars := muxVars(req) - taskID := vars["task-id"] - websocketID := vars["websocket-id"] + taskID := req.PathValue("taskID") + websocketID := req.PathValue("websocketID") st := c.d.overlord.State() st.Lock() diff --git a/internals/daemon/api_test.go b/internals/daemon/api_test.go index 86ce41336..79977363d 100644 --- a/internals/daemon/api_test.go +++ b/internals/daemon/api_test.go @@ -20,7 +20,6 @@ import ( "crypto/x509" "encoding/json" "math/big" - "net/http" "net/http/httptest" "os" "time" @@ -40,9 +39,6 @@ type apiSuite struct { pebbleDir string - vars map[string]string - - restoreMuxVars func() overlordStarted bool } @@ -53,7 +49,6 @@ func (s *apiSuite) SetUpTest(c *check.C) { c.Fatalf("cannot start reaper: %v", err) } - s.restoreMuxVars = FakeMuxVars(s.muxVars) s.pebbleDir = c.MkDir() } @@ -64,7 +59,6 @@ func (s *apiSuite) TearDownTest(c *check.C) { } s.d = nil s.pebbleDir = "" - s.restoreMuxVars() err := reaper.Stop() if err != nil { @@ -73,10 +67,6 @@ func (s *apiSuite) TearDownTest(c *check.C) { plan.UnregisterSectionExtension(pairingstate.PairingField) } -func (s *apiSuite) muxVars(*http.Request) map[string]string { - return s.vars -} - func (s *apiSuite) daemon(c *check.C) *Daemon { if s.d != nil { panic("called daemon() twice") diff --git a/internals/daemon/daemon.go b/internals/daemon/daemon.go index 982d3f8b3..12a85c3f9 100644 --- a/internals/daemon/daemon.go +++ b/internals/daemon/daemon.go @@ -35,7 +35,6 @@ import ( "syscall" "time" - "github.com/gorilla/mux" "gopkg.in/tomb.v2" "github.com/canonical/pebble/internals/logger" @@ -179,7 +178,7 @@ type Daemon struct { connTracker *connTracker serve *http.Server tomb tomb.Tomb - router *mux.Router + router *http.ServeMux standbyOpinions *standby.StandbyOpinions // set to what kind of restart was requested (if any) @@ -205,9 +204,8 @@ type ResponseFunc func(*Command, *http.Request, *UserState) Response // A Command routes a request to an individual per-verb ResponseFUnc type Command struct { - Path string - PathPrefix string - // + Path string + GET ResponseFunc PUT ResponseFunc POST ResponseFunc @@ -517,20 +515,14 @@ func (d *Daemon) SetDegradedMode(err error) { } func (d *Daemon) addRoutes() { - d.router = mux.NewRouter() + d.router = http.NewServeMux() for _, c := range API { c.d = d - if c.PathPrefix == "" { - d.router.Handle(c.Path, c).Name(c.Path) - } else { - d.router.PathPrefix(c.PathPrefix).Handler(c).Name(c.PathPrefix) - } + d.router.Handle(c.Path, c) } - // also maybe add a /favicon.ico handler... - - d.router.NotFoundHandler = NotFound("invalid API endpoint requested") + d.router.Handle("/", NotFound("invalid API endpoint requested")) } type connTracker struct { diff --git a/internals/daemon/daemon_test.go b/internals/daemon/daemon_test.go index 56d11c9b8..800409d85 100644 --- a/internals/daemon/daemon_test.go +++ b/internals/daemon/daemon_test.go @@ -42,7 +42,6 @@ import ( "time" "github.com/GehirnInc/crypt/sha512_crypt" - "github.com/gorilla/mux" . "gopkg.in/check.v1" "github.com/canonical/pebble/cmd" @@ -239,12 +238,16 @@ func (s *daemonSuite) TestAddCommand(c *C) { }() d := s.newDaemon(c) - d.Init() - c.Assert(d.Start(), IsNil) - defer d.Stop(nil) - result := d.router.Get(endpoint).GetHandler() - c.Assert(result, Equals, &command) + ctx := context.WithValue(context.Background(), TransportTypeKey{}, TransportTypeUnixSocket) + req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil) + c.Assert(err, IsNil) + + rec := httptest.NewRecorder() + d.router.ServeHTTP(rec, req) + c.Assert(handler.cmd, Equals, &command) + c.Assert(handler.lastMethod, Equals, "GET") + c.Assert(rec.Code, Equals, http.StatusOK) } func (s *daemonSuite) TestExplicitPaths(c *C) { @@ -632,27 +635,6 @@ func (s *daemonSuite) TestDefaultUcredUsers(c *C) { c.Check(*userSeen.UID, Equals, uint32(os.Getuid()+1)) } -func (s *daemonSuite) TestAddRoutes(c *C) { - d := s.newDaemon(c) - - expected := make([]string, len(API)) - for i, v := range API { - if v.PathPrefix != "" { - expected[i] = v.PathPrefix - continue - } - expected[i] = v.Path - } - - got := make([]string, 0, len(API)) - c.Assert(d.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { - got = append(got, route.GetName()) - return nil - }), IsNil) - - c.Check(got, DeepEquals, expected) // this'll stop being true if routes are added that aren't commands (e.g. for the favicon) -} - type witnessAcceptListener struct { net.Listener @@ -1518,7 +1500,7 @@ func (s *daemonSuite) TestWritesRequireAdminAccess(c *C) { } // Task websockets (GET) is used for exec, so requires admin access too. - cmd = apiCmd("/v1/tasks/{task-id}/websocket/{websocket-id}") + cmd = apiCmd("/v1/tasks/{taskID}/websocket/{websocketID}") switch cmd.ReadAccess.(type) { case OpenAccess, UserAccess: c.Errorf("%s ReadAccess should be AdminAccess, not %T", cmd.Path, cmd.WriteAccess) diff --git a/internals/daemon/export_test.go b/internals/daemon/export_test.go index 5d8cc251e..b236bb33d 100644 --- a/internals/daemon/export_test.go +++ b/internals/daemon/export_test.go @@ -15,7 +15,6 @@ package daemon import ( - "net/http" "time" "github.com/canonical/pebble/internals/overlord" @@ -23,14 +22,6 @@ import ( "github.com/canonical/pebble/internals/overlord/state" ) -func FakeMuxVars(f func(*http.Request) map[string]string) (restore func()) { - old := muxVars - muxVars = f - return func() { - muxVars = old - } -} - func FakeStateEnsureBefore(f func(st *state.State, d time.Duration)) (restore func()) { old := stateEnsureBefore stateEnsureBefore = f