Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 76 additions & 8 deletions pkg/operator/encryption/kms/health/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,29 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/pflag"

"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/apiserver/pkg/server"
k8senvelopekmsv2 "k8s.io/apiserver/pkg/storage/value/encrypt/envelope/kmsv2"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
"k8s.io/klog/v2"
)

const providerName = "kms-health-reporter"

// kmsSocketPattern matches the socket path each co-located KMSv2 plugin is
// mounted at, e.g. unix:///var/run/kmsplugin/kms-1.sock.
var kmsSocketPattern = regexp.MustCompile(`^unix:///var/run/kmsplugin/kms-\d+\.sock$`)
var kmsSocketPattern = regexp.MustCompile(`^unix:///var/run/kmsplugin/kms-(\d+)\.sock$`)

// keyIDFromSocket returns the sequential key id embedded in a KMS plugin
// socket path, i.e. the digits captured by kmsSocketPattern; for
// unix:///var/run/kmsplugin/kms-1.sock that is "1". A socket that does not
// match the pattern yields an error naming the socket and the pattern.
func keyIDFromSocket(socket string) (string, error) {
	if groups := kmsSocketPattern.FindStringSubmatch(socket); groups != nil {
		// groups[0] is the whole match; groups[1] is the captured number.
		return groups[1], nil
	}
	return "", fmt.Errorf("socket %q must match %s", socket, kmsSocketPattern)
}

// options' flag-bound fields are exported so the struct can be logged as a
// whole via klog.InfoS, which JSON-marshals its values.
Expand Down Expand Up @@ -45,7 +60,7 @@ func NewCommand(ctx context.Context, newOperatorClient func(*rest.Config) (v1hel
if err := o.validate(); err != nil {
return err
}
return o.run()
return o.run(ctx)
},
}
o.addFlags(cmd.Flags())
Expand All @@ -65,10 +80,15 @@ func (o *options) validate() error {
if len(o.KMSSockets) == 0 {
return fmt.Errorf("--kms-sockets is required, at least one")
}
socketSet := make(map[string]struct{}, len(o.KMSSockets))
for _, s := range o.KMSSockets {
if !kmsSocketPattern.MatchString(s) {
return fmt.Errorf("--kms-sockets entry %q must match %s", s, kmsSocketPattern)
}
if _, ok := socketSet[s]; ok {
return fmt.Errorf("--kms-sockets entry %q is duplicated", s)
}
socketSet[s] = struct{}{}
}

if o.Interval <= 0 {
Expand All @@ -87,8 +107,12 @@ func (o *options) validate() error {
return nil
}

func (o *options) run() error {
cfg, err := buildRESTConfig(o.Kubeconfig)
func (o *options) run(ctx context.Context) error {
ctx = setupSignalContext(ctx)

// Empty kubeconfig falls back to the in-cluster config (service account
// token + KUBERNETES_SERVICE_HOST), which is the deployed path.
cfg, err := clientcmd.BuildConfigFromFlags("", o.Kubeconfig)

@ardaguclu ardaguclu Jun 12, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Ideally this should be in the validate/complete function, not run. But this is definitely not a blocker.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will give it a try. Could be cool.

if err != nil {
return fmt.Errorf("build rest config: %w", err)
}
Expand All @@ -97,14 +121,58 @@ func (o *options) run() error {
return fmt.Errorf("build operator client: %w", err)
}

plugins, err := buildPlugins(ctx, o.KMSSockets, o.ReadTimeout)
if err != nil {
return err
}
prober := newProber(plugins)

klog.InfoS("kms-health-reporter starting", "config", o)

wait.JitterUntilWithContext(ctx, func(ctx context.Context) {
// Each Status RPC enforces o.ReadTimeout internally (set at dial time);
// ctx here only carries shutdown cancellation.
conditions := prober.probeAll(ctx)
// TODO: hand conditions to the writer once it lands; logging is a placeholder.
klog.InfoS("kms plugin health", "conditions", conditions)
}, o.Interval, 0.1, false)

return nil
}

func buildRESTConfig(kubeconfig string) (*rest.Config, error) {
if kubeconfig != "" {
return clientcmd.BuildConfigFromFlags("", kubeconfig)
func buildPlugins(ctx context.Context, sockets []string, timeout time.Duration) ([]pluginClient, error) {
plugins := make([]pluginClient, 0, len(sockets))

for _, socket := range sockets {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realized that we don't check for duplicates: in theory a "user" can pass --kms-sockets kms-1.sock,kms-1.sock. We could add a check in a new PR.

keyID, err := keyIDFromSocket(socket)
if err != nil {
return nil, err
}

// Unique name per plugin so the gRPC client's KMS operation metrics
// don't merge both plugins into one series.
service, err := k8senvelopekmsv2.NewGRPCService(ctx, socket, providerName+"-"+keyID, timeout)
if err != nil {
// With the current dependency version this should never happen with a validated GRPC endpoint.
return nil, fmt.Errorf("setting up grpc service failed at %q: %w", socket, err)
}

plugins = append(plugins, pluginClient{keyID: keyID, service: service})
}
return rest.InClusterConfig()

return plugins, nil
}

// setupSignalContext registers for SIGTERM and SIGINT and returns a context
// derived from baseCtx that is cancelled once one of those signals is
// received. Compare startupmonitor's setupSignalContext.
//
// The watcher goroutine also returns when baseCtx is cancelled, so it does
// not stay blocked waiting for a signal that may never arrive.
//
// NOTE(review): server.SetupSignalHandler is documented as callable only once
// per process — confirm no other caller registers it as well.
func setupSignalContext(baseCtx context.Context) context.Context {
	shutdownCtx, cancel := context.WithCancel(baseCtx)
	shutdownHandler := server.SetupSignalHandler()
	go func() {
		defer cancel()
		select {
		case <-shutdownHandler:
			klog.Infof("Received SIGTERM or SIGINT signal, shutting down the process.")
		case <-baseCtx.Done():
			// The parent is already cancelled, which cancels shutdownCtx
			// too; nothing left to wait for.
		}
	}()
	return shutdownCtx
}
5 changes: 5 additions & 0 deletions pkg/operator/encryption/kms/health/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func TestValidate(t *testing.T) {
name: "multiple valid sockets",
mutate: func(o *options) { o.KMSSockets = append(o.KMSSockets, "unix:///var/run/kmsplugin/kms-2.sock") },
},
{
name: "duplicate sockets",
mutate: func(o *options) { o.KMSSockets = append(o.KMSSockets, o.KMSSockets[0]) },
wantErr: true,
},
{
name: "socket missing unix scheme",
mutate: func(o *options) { o.KMSSockets = []string{"/var/run/kmsplugin/kms-1.sock"} },
Expand Down
93 changes: 93 additions & 0 deletions pkg/operator/encryption/kms/health/prober.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package health

import (
"context"
"sync"
"time"

kmsservice "k8s.io/kms/pkg/service"
)

// healthzOK is the value the KMS plugin returns when healthy.
// See https://github.com/kubernetes/kubernetes/blob/master/staging/src/k8s.io/kms/apis/v2/api.proto#L39
const healthzOK = "ok"

// Status values recorded in pluginHealthReport.Status.
const (
// statusHealthy: the plugin answered and reported Healthz "ok".
statusHealthy = "healthy"
// statusUnhealthy: the plugin answered with any other Healthz value.
statusUnhealthy = "unhealthy"
// statusError: the Status call failed or returned a nil response.
statusError = "error"
)

// pluginHealthReport is the outcome of a single health probe of one KMS
// plugin.
type pluginHealthReport struct {
// KeyID is the controller's sequential key id; KEKID is the KMS provider's
// encryption key id. Distinct identifiers, easy to confuse.
KeyID string
// KEKID is only filled in when the plugin reported itself healthy.
KEKID string
// Status is one of statusHealthy, statusUnhealthy or statusError.
Status string
// LastChecked is the prober's clock reading taken just before the Status call.
LastChecked time.Time
// Detail carries the error text or the non-ok Healthz value; it is left
// empty for a healthy plugin.
Detail string
}

// pluginClient is the dialed handle to one co-located KMS plugin; the plugin
// itself is a separate process behind the unix socket.
type pluginClient struct {
// keyID identifies the plugin in reports; buildPlugins derives it from the
// socket path.
keyID string
// service is the KMS service whose Status method is called on every probe.
service kmsservice.Service
}

// prober checks the health of a fixed set of KMS plugins.
type prober struct {
// plugins are the clients probed on every probeAll call.
plugins []pluginClient
// now supplies the report timestamps; newProber sets it to time.Now and
// tests substitute a fixed clock.
now func() time.Time
}

// newProber returns a prober for the given plugins that timestamps its
// reports with the wall clock.
func newProber(plugins []pluginClient) *prober {
	p := new(prober)
	p.plugins = plugins
	p.now = time.Now
	return p
}

// probeAll never returns an error: a failed probe is encoded as a report
// with Status "error" so the caller always gets one entry per plugin.
// Probes run concurrently so one hung plugin doesn't delay the others;
// worst-case duration is one read-timeout, not the sum.
func (p *prober) probeAll(ctx context.Context) []pluginHealthReport {
reports := make([]pluginHealthReport, len(p.plugins))

var wg sync.WaitGroup
for i, plugin := range p.plugins {
wg.Go(func() {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is a clean usage of go routines. didn't know you can use wg.Go

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've never used it, but I had already heard of it, so here I am giving it a try.

reports[i] = p.probe(ctx, plugin)
})
}
wg.Wait()

return reports
}

func (p *prober) probe(ctx context.Context, plugin pluginClient) pluginHealthReport {
report := pluginHealthReport{
KeyID: plugin.keyID,
LastChecked: p.now(),
}

resp, err := plugin.service.Status(ctx)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just realized that resp is a pointer and can be nil. should we add a check ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NullPointerException. Always worth checking.

switch {
case err != nil:
report.Status = statusError
report.Detail = err.Error()
case resp == nil:
// The in-tree gRPC client never returns (nil, nil), but a misbehaving
// plugin must not panic the reporter.
report.Status = statusError
report.Detail = "kms plugin returned nil status response"
case resp.Healthz == healthzOK:
report.Status = statusHealthy
report.KEKID = resp.KeyID
default:
report.Status = statusUnhealthy
report.Detail = resp.Healthz
}

return report
}
124 changes: 124 additions & 0 deletions pkg/operator/encryption/kms/health/prober_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package health

import (
"context"
"fmt"
"reflect"
"strconv"
"sync"
"testing"
"time"

"k8s.io/apimachinery/pkg/util/wait"
kmsservice "k8s.io/kms/pkg/service"
)

// fakeService is a canned kmsservice.Service for tests: Status hands back the
// configured resp and err unchanged.
type fakeService struct {
resp *kmsservice.StatusResponse
err error
}

// Status returns the preconfigured response and error.
func (f *fakeService) Status(context.Context) (*kmsservice.StatusResponse, error) {
return f.resp, f.err
}
// Encrypt is a stub that always returns (nil, nil).
func (f *fakeService) Encrypt(context.Context, string, []byte) (*kmsservice.EncryptResponse, error) {
return nil, nil
}
// Decrypt is a stub that always returns (nil, nil).
func (f *fakeService) Decrypt(context.Context, string, *kmsservice.DecryptRequest) ([]byte, error) {
return nil, nil
}

func TestProber_ProbeAll(t *testing.T) {
fixed := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
p := &prober{
plugins: []pluginClient{
{keyID: "1", service: &fakeService{resp: &kmsservice.StatusResponse{Healthz: "ok", KeyID: "kek-abc"}}},
{keyID: "2", service: &fakeService{err: fmt.Errorf("connection refused")}},
{keyID: "3", service: &fakeService{resp: &kmsservice.StatusResponse{Healthz: "degraded"}}},
{keyID: "4", service: &fakeService{}},
},
now: func() time.Time { return fixed },
}

have := p.probeAll(context.Background())
want := []pluginHealthReport{
{KeyID: "1", KEKID: "kek-abc", Status: "healthy", LastChecked: fixed},
{KeyID: "2", Status: "error", Detail: "connection refused", LastChecked: fixed},
{KeyID: "3", Status: "unhealthy", Detail: "degraded", LastChecked: fixed},
{KeyID: "4", Status: "error", Detail: "kms plugin returned nil status response", LastChecked: fixed},
}
if !reflect.DeepEqual(have, want) {
t.Errorf("probeAll():\n have: %+v\n want: %+v", have, want)
}
}

// blockingService releases Status only once all expected probes have
// arrived, so the test passes only if probeAll runs them concurrently.
type blockingService struct {
*fakeService
// barrier is shared by all instances in a test and must be pre-loaded with
// one count per expected Status call.
barrier *sync.WaitGroup
}

// Status checks in at the barrier, waits for every other probe to check in,
// then answers with the embedded fakeService's canned response.
func (b *blockingService) Status(ctx context.Context) (*kmsservice.StatusResponse, error) {
// Done before Wait: each caller decrements the counter and then blocks until
// it reaches zero, i.e. until all expected callers are inside Status.
b.barrier.Done()
b.barrier.Wait()
return b.fakeService.Status(ctx)
}

// TestProber_ProbeAllFansOut proves that probeAll issues its probes
// concurrently: every plugin's Status blocks on a shared barrier until all n
// probes have arrived, so a sequential implementation would never finish and
// the test would hit the timeout branch.
func TestProber_ProbeAllFansOut(t *testing.T) {
	const n = 3
	var barrier sync.WaitGroup
	barrier.Add(n)

	plugins := make([]pluginClient, 0, n)
	for i := range n {
		keyID := strconv.Itoa(i + 1)
		plugins = append(plugins, pluginClient{
			keyID: keyID,
			service: &blockingService{
				fakeService: &fakeService{resp: &kmsservice.StatusResponse{Healthz: "ok", KeyID: "kek-" + keyID}},
				barrier:     &barrier,
			},
		})
	}
	p := newProber(plugins)

	// Buffered so the goroutine can deliver its result and exit even if the
	// test has already given up on the timeout branch.
	done := make(chan []pluginHealthReport, 1)
	go func() { done <- p.probeAll(context.Background()) }()

	select {
	case have := <-done:
		// Without this check a short or empty result would skip the loop
		// below and let the test pass vacuously.
		if len(have) != n {
			t.Fatalf("probeAll() returned %d reports, want %d", len(have), n)
		}
		for i, report := range have {
			want := strconv.Itoa(i + 1)
			if report.KeyID != want || report.KEKID != "kek-"+want {
				t.Errorf("reports[%d] = {KeyID:%q KEKID:%q}, want {KeyID:%q KEKID:%q}",
					i, report.KeyID, report.KEKID, want, "kek-"+want)
			}
		}
	case <-time.After(wait.ForeverTestTimeout):
		t.Fatal("probeAll timed out: probes ran sequentially or deadlocked")
	}
}

// Test_keyIDFromSocket checks that the numeric key id is extracted from
// well-formed socket paths and that a non-matching path is rejected.
func Test_keyIDFromSocket(t *testing.T) {
	type testCase struct {
		socket  string
		want    string
		wantErr bool
	}
	cases := []testCase{
		{socket: "unix:///var/run/kmsplugin/kms-1.sock", want: "1"},
		{socket: "unix:///var/run/kmsplugin/kms-42.sock", want: "42"},
		{socket: "unix:///var/run/kmsplugin/plugin.sock", wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.socket, func(t *testing.T) {
			got, err := keyIDFromSocket(tc.socket)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Fatalf("keyIDFromSocket(%q) err = %v, wantErr %v", tc.socket, err, tc.wantErr)
			}
			if got != tc.want {
				t.Errorf("keyIDFromSocket(%q) = %q, want %q", tc.socket, got, tc.want)
			}
		})
	}
}