From 812dbbf9c38c081c4698419b81ec13e830569f75 Mon Sep 17 00:00:00 2001
From: Krzysztof Ostrowski
Date: Thu, 11 Jun 2026 17:35:08 +0200
Subject: [PATCH] pkg/operator/encryption/kms/health: add KMS plugin health checker

Add prober, which queries the Status endpoint of each co-located KMSv2
plugin over its unix domain socket (UDS) and reports per-plugin health.
---
 pkg/operator/encryption/kms/health/cmd.go     |  84 ++++++++++--
 .../encryption/kms/health/cmd_test.go         |   5 +
 pkg/operator/encryption/kms/health/prober.go  |  93 +++++++++++++
 .../encryption/kms/health/prober_test.go      | 124 ++++++++++++++++++
 4 files changed, 298 insertions(+), 8 deletions(-)
 create mode 100644 pkg/operator/encryption/kms/health/prober.go
 create mode 100644 pkg/operator/encryption/kms/health/prober_test.go

diff --git a/pkg/operator/encryption/kms/health/cmd.go b/pkg/operator/encryption/kms/health/cmd.go
index d6733848c6..4bd083761e 100644
--- a/pkg/operator/encryption/kms/health/cmd.go
+++ b/pkg/operator/encryption/kms/health/cmd.go
@@ -10,14 +10,29 @@ import (
 	"github.com/spf13/cobra"
 	"github.com/spf13/pflag"

+	"k8s.io/apimachinery/pkg/util/wait"
+	"k8s.io/apiserver/pkg/server"
+	k8senvelopekmsv2 "k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2"
 	"k8s.io/client-go/rest"
 	"k8s.io/client-go/tools/clientcmd"
 	"k8s.io/klog/v2"
 )

+const providerName = "kms-health-reporter"
+
 // kmsSocketPattern matches the socket path each co-located KMSv2 plugin is
 // mounted at, e.g. unix:///var/run/kmsplugin/kms-1.sock.
-var kmsSocketPattern = regexp.MustCompile(`^unix:///var/run/kmsplugin/kms-\d+\.sock$`)
+var kmsSocketPattern = regexp.MustCompile(`^unix:///var/run/kmsplugin/kms-(\d+)\.sock$`)
+
+// keyIDFromSocket extracts the sequential key id captured by kmsSocketPattern,
+// e.g. "1" from unix:///var/run/kmsplugin/kms-1.sock.
+func keyIDFromSocket(socket string) (string, error) {
+	m := kmsSocketPattern.FindStringSubmatch(socket)
+	if m == nil {
+		return "", fmt.Errorf("socket %q must match %s", socket, kmsSocketPattern)
+	}
+	return m[1], nil
+}

 // options' flag-bound fields are exported so the struct can be logged as a
 // whole via klog.InfoS, which JSON-marshals its values.
@@ -45,7 +60,7 @@ func NewCommand(ctx context.Context, newOperatorClient func(*rest.Config) (v1hel
 			if err := o.validate(); err != nil {
 				return err
 			}
-			return o.run()
+			return o.run(ctx)
 		},
 	}
 	o.addFlags(cmd.Flags())
@@ -65,10 +80,15 @@ func (o *options) validate() error {
 	if len(o.KMSSockets) == 0 {
 		return fmt.Errorf("--kms-sockets is required, at least one")
 	}
+	socketSet := make(map[string]struct{}, len(o.KMSSockets))
 	for _, s := range o.KMSSockets {
 		if !kmsSocketPattern.MatchString(s) {
 			return fmt.Errorf("--kms-sockets entry %q must match %s", s, kmsSocketPattern)
 		}
+		if _, ok := socketSet[s]; ok {
+			return fmt.Errorf("--kms-sockets entry %q is duplicated", s)
+		}
+		socketSet[s] = struct{}{}
 	}

 	if o.Interval <= 0 {
@@ -87,8 +107,12 @@ func (o *options) validate() error {
 	return nil
 }

-func (o *options) run() error {
-	cfg, err := buildRESTConfig(o.Kubeconfig)
+func (o *options) run(ctx context.Context) error {
+	ctx = setupSignalContext(ctx)
+
+	// Empty kubeconfig falls back to the in-cluster config (service account
+	// token + KUBERNETES_SERVICE_HOST), which is the deployed path.
+	cfg, err := clientcmd.BuildConfigFromFlags("", o.Kubeconfig)
 	if err != nil {
 		return fmt.Errorf("build rest config: %w", err)
 	}
@@ -97,14 +121,58 @@ func (o *options) run() error {
 		return fmt.Errorf("build operator client: %w", err)
 	}

+	plugins, err := buildPlugins(ctx, o.KMSSockets, o.ReadTimeout)
+	if err != nil {
+		return err
+	}
+	prober := newProber(plugins)
+
 	klog.InfoS("kms-health-reporter starting", "config", o)

+	wait.JitterUntilWithContext(ctx, func(ctx context.Context) {
+		// Each Status RPC enforces o.ReadTimeout internally (set at dial time);
+		// ctx here only carries shutdown cancellation.
+		conditions := prober.probeAll(ctx)
+		// TODO: hand conditions to the writer once it lands; logging is a placeholder.
+		klog.InfoS("kms plugin health", "conditions", conditions)
+	}, o.Interval, 0.1, false)
+
 	return nil
 }

-func buildRESTConfig(kubeconfig string) (*rest.Config, error) {
-	if kubeconfig != "" {
-		return clientcmd.BuildConfigFromFlags("", kubeconfig)
+func buildPlugins(ctx context.Context, sockets []string, timeout time.Duration) ([]pluginClient, error) {
+	plugins := make([]pluginClient, 0, len(sockets))
+
+	for _, socket := range sockets {
+		keyID, err := keyIDFromSocket(socket)
+		if err != nil {
+			return nil, err
+		}
+
+		// Unique name per plugin so the gRPC client's KMS operation metrics
+		// don't merge all plugins into one series.
+		service, err := k8senvelopekmsv2.NewGRPCService(ctx, socket, providerName+"-"+keyID, timeout)
+		if err != nil {
+			// With the current dependency version this should never happen with a validated gRPC endpoint.
+			return nil, fmt.Errorf("setting up grpc service failed at %q: %w", socket, err)
+		}
+
+		plugins = append(plugins, pluginClient{keyID: keyID, service: service})
 	}
-	return rest.InClusterConfig()
+
+	return plugins, nil
+}
+
+// setupSignalContext registers for SIGTERM and SIGINT and returns a child of
+// baseCtx cancelled on signal; compare startupmonitor's setupSignalContext.
+// NOTE(review): the goroutine never observes baseCtx.Done() — confirm intended.
+func setupSignalContext(baseCtx context.Context) context.Context {
+	shutdownCtx, cancel := context.WithCancel(baseCtx)
+	shutdownHandler := server.SetupSignalHandler()
+	go func() {
+		defer cancel()
+		<-shutdownHandler
+		klog.Infof("Received SIGTERM or SIGINT signal, shutting down the process.")
+	}()
+	return shutdownCtx
 }
diff --git a/pkg/operator/encryption/kms/health/cmd_test.go b/pkg/operator/encryption/kms/health/cmd_test.go
index 4be33e7f20..77f7d6108c 100644
--- a/pkg/operator/encryption/kms/health/cmd_test.go
+++ b/pkg/operator/encryption/kms/health/cmd_test.go
@@ -43,6 +43,11 @@ func TestValidate(t *testing.T) {
 			name:   "multiple valid sockets",
 			mutate: func(o *options) { o.KMSSockets = append(o.KMSSockets, "unix:///var/run/kmsplugin/kms-2.sock") },
 		},
+		{
+			name:    "duplicate sockets",
+			mutate:  func(o *options) { o.KMSSockets = append(o.KMSSockets, o.KMSSockets[0]) },
+			wantErr: true,
+		},
 		{
 			name:   "socket missing unix scheme",
 			mutate: func(o *options) { o.KMSSockets = []string{"/var/run/kmsplugin/kms-1.sock"} },
diff --git a/pkg/operator/encryption/kms/health/prober.go b/pkg/operator/encryption/kms/health/prober.go
new file mode 100644
index 0000000000..84db02fe11
--- /dev/null
+++ b/pkg/operator/encryption/kms/health/prober.go
@@ -0,0 +1,93 @@
+package health
+
+import (
+	"context"
+	"sync"
+	"time"
+
+	kmsservice "k8s.io/kms/pkg/service"
+)
+
+// healthzOK is the value the KMS plugin returns when healthy.
+// See https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/kms/apis/v2/api.proto#L39
+const healthzOK = "ok"
+
+const (
+	statusHealthy   = "healthy"
+	statusUnhealthy = "unhealthy"
+	statusError     = "error"
+)
+
+type pluginHealthReport struct {
+	// KeyID is the controller's sequential key id; KEKID is the KMS provider's
+	// encryption key id. Distinct identifiers, easy to confuse.
+	KeyID       string
+	KEKID       string
+	Status      string
+	LastChecked time.Time
+	Detail      string
+}
+
+// pluginClient is the dialed handle to one co-located KMS plugin; the plugin
+// itself is a separate process behind the unix socket.
+type pluginClient struct {
+	keyID   string
+	service kmsservice.Service
+}
+
+type prober struct {
+	plugins []pluginClient
+	now     func() time.Time
+}
+
+func newProber(plugins []pluginClient) *prober {
+	return &prober{
+		plugins: plugins,
+		now:     time.Now,
+	}
+}
+
+// probeAll never returns an error: a failed probe is encoded as a report
+// with Status "error" so the caller always gets one entry per plugin, in
+// p.plugins order. Probes run concurrently so one hung plugin doesn't delay
+// the others; worst-case duration is one read-timeout, not the sum.
+func (p *prober) probeAll(ctx context.Context) []pluginHealthReport {
+	reports := make([]pluginHealthReport, len(p.plugins))
+
+	var wg sync.WaitGroup
+	for i, plugin := range p.plugins {
+		wg.Go(func() {
+			reports[i] = p.probe(ctx, plugin)
+		})
+	}
+	wg.Wait()
+
+	return reports
+}
+
+func (p *prober) probe(ctx context.Context, plugin pluginClient) pluginHealthReport {
+	report := pluginHealthReport{
+		KeyID:       plugin.keyID,
+		LastChecked: p.now(),
+	}
+
+	resp, err := plugin.service.Status(ctx)
+	switch {
+	case err != nil:
+		report.Status = statusError
+		report.Detail = err.Error()
+	case resp == nil:
+		// The in-tree gRPC client never returns (nil, nil), but a misbehaving
+		// plugin must not panic the reporter.
+		report.Status = statusError
+		report.Detail = "kms plugin returned nil status response"
+	case resp.Healthz == healthzOK:
+		report.Status = statusHealthy
+		report.KEKID = resp.KeyID
+	default:
+		report.Status = statusUnhealthy
+		report.Detail = resp.Healthz
+	}
+
+	return report
+}
diff --git a/pkg/operator/encryption/kms/health/prober_test.go b/pkg/operator/encryption/kms/health/prober_test.go
new file mode 100644
index 0000000000..8f94ff7078
--- /dev/null
+++ b/pkg/operator/encryption/kms/health/prober_test.go
@@ -0,0 +1,124 @@
+package health
+
+import (
+	"context"
+	"fmt"
+	"reflect"
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+
+	"k8s.io/apimachinery/pkg/util/wait"
+	kmsservice "k8s.io/kms/pkg/service"
+)
+
+type fakeService struct {
+	resp *kmsservice.StatusResponse
+	err  error
+}
+
+func (f *fakeService) Status(context.Context) (*kmsservice.StatusResponse, error) {
+	return f.resp, f.err
+}
+func (f *fakeService) Encrypt(context.Context, string, []byte) (*kmsservice.EncryptResponse, error) {
+	return nil, nil
+}
+func (f *fakeService) Decrypt(context.Context, string, *kmsservice.DecryptRequest) ([]byte, error) {
+	return nil, nil
+}
+
+func TestProber_ProbeAll(t *testing.T) {
+	fixed := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
+	p := &prober{
+		plugins: []pluginClient{
+			{keyID: "1", service: &fakeService{resp: &kmsservice.StatusResponse{Healthz: "ok", KeyID: "kek-abc"}}},
+			{keyID: "2", service: &fakeService{err: fmt.Errorf("connection refused")}},
+			{keyID: "3", service: &fakeService{resp: &kmsservice.StatusResponse{Healthz: "degraded"}}},
+			{keyID: "4", service: &fakeService{}},
+		},
+		now: func() time.Time { return fixed },
+	}
+
+	have := p.probeAll(context.Background())
+	want := []pluginHealthReport{
+		{KeyID: "1", KEKID: "kek-abc", Status: "healthy", LastChecked: fixed},
+		{KeyID: "2", Status: "error", Detail: "connection refused", LastChecked: fixed},
+		{KeyID: "3", Status: "unhealthy", Detail: "degraded", LastChecked: fixed},
+		{KeyID: "4", Status: "error", Detail: "kms plugin returned nil status response", LastChecked: fixed},
+	}
+	if !reflect.DeepEqual(have, want) {
+		t.Errorf("probeAll():\n have: %+v\n want: %+v", have, want)
+	}
+}
+
+// blockingService releases Status only once all expected probes have
+// arrived, so the test passes only if probeAll runs them concurrently.
+type blockingService struct {
+	*fakeService
+	barrier *sync.WaitGroup
+}
+
+func (b *blockingService) Status(ctx context.Context) (*kmsservice.StatusResponse, error) {
+	b.barrier.Done()
+	b.barrier.Wait()
+	return b.fakeService.Status(ctx)
+}
+
+func TestProber_ProbeAllFansOut(t *testing.T) {
+	const n = 3
+	var barrier sync.WaitGroup
+	barrier.Add(n)
+
+	plugins := make([]pluginClient, 0, n)
+	for i := range n {
+		keyID := strconv.Itoa(i + 1)
+		plugins = append(plugins, pluginClient{
+			keyID: keyID,
+			service: &blockingService{
+				fakeService: &fakeService{resp: &kmsservice.StatusResponse{Healthz: "ok", KeyID: "kek-" + keyID}},
+				barrier:     &barrier,
+			},
+		})
+	}
+	p := newProber(plugins)
+
+	done := make(chan []pluginHealthReport, 1)
+	go func() { done <- p.probeAll(context.Background()) }()
+
+	select {
+	case have := <-done:
+		for i, report := range have {
+			want := strconv.Itoa(i + 1)
+			if report.KeyID != want || report.KEKID != "kek-"+want {
+				t.Errorf("reports[%d] = {KeyID:%q KEKID:%q}, want {KeyID:%q KEKID:%q}",
+					i, report.KeyID, report.KEKID, want, "kek-"+want)
+			}
+		}
+	case <-time.After(wait.ForeverTestTimeout):
+		t.Fatal("probeAll timed out: probes ran sequentially or deadlocked")
+	}
+}
+
+func Test_keyIDFromSocket(t *testing.T) {
+	tests := []struct {
+		socket  string
+		want    string
+		wantErr bool
+	}{
+		{socket: "unix:///var/run/kmsplugin/kms-1.sock", want: "1"},
+		{socket: "unix:///var/run/kmsplugin/kms-42.sock", want: "42"},
+		{socket: "unix:///var/run/kmsplugin/plugin.sock", wantErr: true},
+	}
+	for _, tt := range tests {
+		t.Run(tt.socket, func(t *testing.T) {
+			have, err := keyIDFromSocket(tt.socket)
+			if (err != nil) != tt.wantErr {
+				t.Fatalf("keyIDFromSocket(%q) err = %v, wantErr %v", tt.socket, err, tt.wantErr)
+			}
+			if have != tt.want {
+				t.Errorf("keyIDFromSocket(%q) = %q, want %q", tt.socket, have, tt.want)
+			}
+		})
+	}
+}