Skip to content

Commit 62da857

Browse files
authored
Merge pull request #180 from canopy-network/issue-#164
Issue-#164: Allow multiple subscribers per `ChainID` to ws updates
2 parents 2cbaf9a + 7f0a5b7 commit 62da857

1 file changed

Lines changed: 30 additions & 20 deletions

File tree

cmd/rpc/sock.go

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@ package rpc
22

33
import (
44
"fmt"
5-
"github.com/canopy-network/canopy/controller"
6-
"github.com/canopy-network/canopy/lib"
7-
"github.com/gorilla/websocket"
8-
"github.com/julienschmidt/httprouter"
95
"net/http"
106
"net/url"
7+
"slices"
118
"strconv"
129
"sync"
1310
"time"
11+
12+
"github.com/canopy-network/canopy/controller"
13+
"github.com/canopy-network/canopy/lib"
14+
"github.com/gorilla/websocket"
15+
"github.com/julienschmidt/httprouter"
1416
)
1517

1618
/* This file implements the client & server logic for the 'root-chain info' and corresponding 'on-demand' calls to the rpc */
@@ -23,7 +25,7 @@ const chainIdParamName = "chainId"
2325
type RCManager struct {
2426
c lib.Config // the global node config
2527
subscriptions map[uint64]*RCSubscription // chainId -> subscription
26-
subscribers map[uint64]*RCSubscriber // chainId -> subscriber
28+
subscribers map[uint64][]*RCSubscriber // chainId -> subscribers
2729
l *sync.Mutex // thread safety
2830
afterRCUpdate func(info *lib.RootChainInfo) // callback after the root chain info update
2931
upgrader websocket.Upgrader // upgrade http connection to ws
@@ -36,7 +38,7 @@ func NewRCManager(controller *controller.Controller, config lib.Config, logger l
3638
manager = &RCManager{
3739
c: config,
3840
subscriptions: make(map[uint64]*RCSubscription),
39-
subscribers: make(map[uint64]*RCSubscriber),
41+
subscribers: make(map[uint64][]*RCSubscriber),
4042
l: controller.Mutex,
4143
afterRCUpdate: controller.UpdateRootChainInfo,
4244
upgrader: websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }},
@@ -65,14 +67,13 @@ func (r *RCManager) Publish(chainId uint64, info *lib.RootChainInfo) {
6567
return
6668
}
6769
// for each ws client
68-
for _, subscriber := range r.subscribers {
69-
// skip if not this chain id
70-
if subscriber.chainId != chainId {
71-
continue
72-
}
70+
for _, subscriber := range r.subscribers[chainId] {
7371
// publish to each client
7472
if e := subscriber.conn.WriteMessage(websocket.BinaryMessage, protoBytes); e != nil {
75-
subscriber.Stop(e)
73+
// defer the Stop() call to prevent the slice modification during iteration.
74+
// since Stop() removes the subscriber from r.subscribers, immediate execution
75+
// would affect the slice that is currently being iterated.
76+
defer subscriber.Stop(e)
7677
}
7778
}
7879
}
@@ -82,10 +83,17 @@ func (r *RCManager) ChainIds() (list []uint64) {
8283
// de-duplicate the results
8384
deDupe := lib.NewDeDuplicator[uint64]()
8485
// for each client
85-
for _, client := range r.subscribers {
86+
for chainId, chainSubscribers := range r.subscribers {
8687
// if the client chain id isn't empty and not duplicate
87-
if client.chainId != 0 && !deDupe.Found(client.chainId) {
88-
list = append(list, client.chainId)
88+
for _, subscriber := range chainSubscribers {
89+
if subscriber.chainId != chainId {
90+
// remove subscriber with incorrect chain id
91+
subscriber.Stop(lib.ErrWrongChainId())
92+
continue
93+
}
94+
if subscriber.chainId != 0 && !deDupe.Found(subscriber.chainId) {
95+
list = append(list, subscriber.chainId)
96+
}
8997
}
9098
}
9199
return
@@ -424,16 +432,18 @@ func (r *RCManager) AddSubscriber(subscriber *RCSubscriber) {
424432
r.l.Lock()
425433
defer r.l.Unlock()
426434
// add to the map
427-
r.subscribers[subscriber.chainId] = subscriber
435+
r.subscribers[subscriber.chainId] = append(r.subscribers[subscriber.chainId], subscriber)
428436
}
429437

430438
// RemoveSubscriber() gracefully deletes a RC subscriber
431-
func (r *RCManager) RemoveSubscriber(chainId uint64) {
439+
func (r *RCManager) RemoveSubscriber(chainId uint64, subscriber *RCSubscriber) {
432440
// lock for thread safety
433441
r.l.Lock()
434442
defer r.l.Unlock()
435-
// remove from the map
436-
delete(r.subscribers, chainId)
443+
// remove from the slice
444+
r.subscribers[chainId] = slices.DeleteFunc(r.subscribers[chainId], func(sub *RCSubscriber) bool {
445+
return sub == subscriber
446+
})
437447
}
438448

439449
// Stop() stops the client
@@ -445,5 +455,5 @@ func (r *RCSubscriber) Stop(err error) {
445455
r.log.Error(err.Error())
446456
}
447457
// remove from the manager
448-
r.manager.RemoveSubscriber(r.chainId)
458+
r.manager.RemoveSubscriber(r.chainId, r)
449459
}

0 commit comments

Comments
 (0)