@@ -2,15 +2,17 @@ package rpc
22
33import (
44 "fmt"
5- "github.com/canopy-network/canopy/controller"
6- "github.com/canopy-network/canopy/lib"
7- "github.com/gorilla/websocket"
8- "github.com/julienschmidt/httprouter"
95 "net/http"
106 "net/url"
7+ "slices"
118 "strconv"
129 "sync"
1310 "time"
11+
12+ "github.com/canopy-network/canopy/controller"
13+ "github.com/canopy-network/canopy/lib"
14+ "github.com/gorilla/websocket"
15+ "github.com/julienschmidt/httprouter"
1416)
1517
1618/* This file implements the client & server logic for the 'root-chain info' and corresponding 'on-demand' calls to the rpc */
@@ -23,7 +25,7 @@ const chainIdParamName = "chainId"
2325type RCManager struct {
2426 c lib.Config // the global node config
2527 subscriptions map [uint64 ]* RCSubscription // chainId -> subscription
26- subscribers map [uint64 ]* RCSubscriber // chainId -> subscriber
28+ subscribers map [uint64 ][] * RCSubscriber // chainId -> subscribers
2729 l * sync.Mutex // thread safety
2830 afterRCUpdate func (info * lib.RootChainInfo ) // callback after the root chain info update
2931 upgrader websocket.Upgrader // upgrade http connection to ws
@@ -36,7 +38,7 @@ func NewRCManager(controller *controller.Controller, config lib.Config, logger l
3638 manager = & RCManager {
3739 c : config ,
3840 subscriptions : make (map [uint64 ]* RCSubscription ),
39- subscribers : make (map [uint64 ]* RCSubscriber ),
41+ subscribers : make (map [uint64 ][] * RCSubscriber ),
4042 l : controller .Mutex ,
4143 afterRCUpdate : controller .UpdateRootChainInfo ,
4244 upgrader : websocket.Upgrader {CheckOrigin : func (r * http.Request ) bool { return true }},
@@ -65,14 +67,13 @@ func (r *RCManager) Publish(chainId uint64, info *lib.RootChainInfo) {
6567 return
6668 }
6769 // for each ws client
68- for _ , subscriber := range r .subscribers {
69- // skip if not this chain id
70- if subscriber .chainId != chainId {
71- continue
72- }
70+ for _ , subscriber := range r .subscribers [chainId ] {
7371 // publish to each client
7472 if e := subscriber .conn .WriteMessage (websocket .BinaryMessage , protoBytes ); e != nil {
75- subscriber .Stop (e )
73+ // defer the Stop() call to prevent the slice modification during iteration.
74+ // since Stop() removes the subscriber from r.subscribers, immediate execution
75+ // would affect the slice that is currently being iterated.
76+ defer subscriber .Stop (e )
7677 }
7778 }
7879}
@@ -82,10 +83,17 @@ func (r *RCManager) ChainIds() (list []uint64) {
8283 // de-duplicate the results
8384 deDupe := lib .NewDeDuplicator [uint64 ]()
8485 // for each client
85- for _ , client := range r .subscribers {
86+ for chainId , chainSubscribers := range r .subscribers {
8687 // if the client chain id isn't empty and not duplicate
87- if client .chainId != 0 && ! deDupe .Found (client .chainId ) {
88- list = append (list , client .chainId )
88+ for _ , subscriber := range chainSubscribers {
89+ if subscriber .chainId != chainId {
90+ // remove subscriber with incorrect chain id
91+ subscriber .Stop (lib .ErrWrongChainId ())
92+ continue
93+ }
94+ if subscriber .chainId != 0 && ! deDupe .Found (subscriber .chainId ) {
95+ list = append (list , subscriber .chainId )
96+ }
8997 }
9098 }
9199 return
@@ -424,16 +432,18 @@ func (r *RCManager) AddSubscriber(subscriber *RCSubscriber) {
424432 r .l .Lock ()
425433 defer r .l .Unlock ()
426434 // add to the map
427- r .subscribers [subscriber .chainId ] = subscriber
435+ r .subscribers [subscriber .chainId ] = append ( r . subscribers [ subscriber . chainId ], subscriber )
428436}
429437
430438// RemoveSubscriber() gracefully deletes a RC subscriber
431- func (r * RCManager ) RemoveSubscriber (chainId uint64 ) {
439+ func (r * RCManager ) RemoveSubscriber (chainId uint64 , subscriber * RCSubscriber ) {
432440 // lock for thread safety
433441 r .l .Lock ()
434442 defer r .l .Unlock ()
435- // remove from the map
436- delete (r .subscribers , chainId )
443+ // remove from the slice
444+ r .subscribers [chainId ] = slices .DeleteFunc (r .subscribers [chainId ], func (sub * RCSubscriber ) bool {
445+ return sub == subscriber
446+ })
437447}
438448
439449// Stop() stops the client
@@ -445,5 +455,5 @@ func (r *RCSubscriber) Stop(err error) {
445455 r .log .Error (err .Error ())
446456 }
447457 // remove from the manager
448- r .manager .RemoveSubscriber (r .chainId )
458+ r .manager .RemoveSubscriber (r .chainId , r )
449459}
0 commit comments