diff --git a/client.go b/client.go index 11b889b3..45f5c5a6 100644 --- a/client.go +++ b/client.go @@ -237,6 +237,9 @@ type DMap interface { // Redis client has retransmission logic in case of timeouts, pipeline // can be retransmitted and commands can be executed more than once. Pipeline(opts ...PipelineOption) (*DMapPipeline, error) + + // Close stops background routines and frees allocated resources. + Close(ctx context.Context) error } // PipelineOption is a function for defining options to control behavior of the Pipeline command. diff --git a/cluster_client.go b/cluster_client.go index 288b77f4..b7412a63 100644 --- a/cluster_client.go +++ b/cluster_client.go @@ -395,6 +395,11 @@ func (dm *ClusterDMap) LockWithTimeout(ctx context.Context, key string, timeout, }, nil } +// Close stops background routines and frees allocated resources. +func (dm *ClusterDMap) Close(_ context.Context) error { + return nil +} + // Unlock releases the distributed lock associated with the current context by using the provided context for execution. func (c *ClusterLockContext) Unlock(ctx context.Context) error { rc, err := c.dm.clusterClient.smartPick(c.dm.name, c.key) @@ -750,7 +755,7 @@ func WithRoutingTableFetchInterval(interval time.Duration) ClusterClientOption { // fetchRoutingTable updates the cluster routing table by fetching the latest version from the cluster. // It initializes the partition count if it's the first invocation. Returns an error if fetching fails. func (cl *ClusterClient) fetchRoutingTable() error { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(cl.ctx) defer cancel() routingTable, err := cl.RoutingTable(ctx) diff --git a/embedded_client.go b/embedded_client.go index cc7e75a5..a4638566 100644 --- a/embedded_client.go +++ b/embedded_client.go @@ -131,7 +131,7 @@ func (e *EmbeddedClient) RefreshMetadata(_ context.Context) error { // * Count // * Match func (dm *EmbeddedDMap) Scan(ctx context.Context, options ...ScanOption) (Iterator, error) { - cc, err := NewClusterClient([]string{dm.client.db.rt.This().String()}) + cc, err := dm.setOrGetClusterClient() if err != nil { return nil, err } @@ -280,6 +280,18 @@ func (dm *EmbeddedDMap) Put(ctx context.Context, key string, value interface{}, return nil } +// Close stops background routines and frees allocated resources. +func (dm *EmbeddedDMap) Close(ctx context.Context) error { + dm.mtx.RLock() + clusterClient := dm.clusterClient + dm.mtx.RUnlock() + + if clusterClient != nil { + return dm.clusterClient.Close(ctx) + } + return nil +} + func (e *EmbeddedClient) NewDMap(name string, options ...DMapOption) (DMap, error) { dm, err := e.db.dmap.NewDMap(name) if err != nil { diff --git a/embedded_client_test.go b/embedded_client_test.go index fd7d8f58..60a8cc11 100644 --- a/embedded_client_test.go +++ b/embedded_client_test.go @@ -17,6 +17,7 @@ package olric import ( "context" "fmt" + "runtime" "testing" "time" @@ -639,3 +640,62 @@ func TestEmbeddedClient_DMap_Put_PX_With_NX(t *testing.T) { require.NoError(t, err) assert.NotZero(t, gr.TTL()) } + +func TestEmbeddedClient_Issue263(t *testing.T) { + initNumRoutines := runtime.NumGoroutine() + + cluster := newTestOlricCluster(t) + db := cluster.addMember(t) + + e := db.NewEmbeddedClient() + ctx, cancel := context.WithCancel(context.Background()) + dm, err := e.NewDMap("mydmap") + require.NoError(t, err) + + // Create N key-value pairs: + const N = 100 + for i := range N { + key := fmt.Sprintf("key-%d", i) + value := fmt.Sprintf("value-%d", i) + err := dm.Put(ctx, key, value) + require.NoError(t, err) + } + + // Iterate M times over N keys: + const M = 100 + for range M { + iter, err := dm.Scan(ctx) + require.NoError(t, err) + for iter.Next() { + // Do nothing + } + iter.Close() + } + + require.NoError(t, dm.Close(ctx)) + require.NoError(t, e.Close(ctx)) + require.NoError(t, db.Shutdown(ctx)) + + cancel() + + assert.Equal(t, initNumRoutines, runtime.NumGoroutine()) + + runtime.GC() + time.Sleep(time.Second) + + s := runtime.MemStats{} + runtime.ReadMemStats(&s) + + const ( + KB = 1 << 10 + MB = KB << 10 + ) + + buf := make([]byte, MB) + stackSize := runtime.Stack(buf, true) + + t.Logf("Non-freed objects: %d\n", s.Mallocs-s.Frees) + t.Logf("Mem in use (KB): %d\n", s.HeapAlloc/KB) + t.Logf("Go-routines remained: %d\n", runtime.NumGoroutine()) + t.Logf("Stack traces:\n%s\n", buf[:stackSize]) +}