From 22ba4c912a248169e804e39bec5018812a569fae Mon Sep 17 00:00:00 2001 From: Diego Date: Mon, 1 Jun 2026 08:43:33 -0300 Subject: [PATCH] perf(transport): Optimize protocol name construction on substream open --- src/multistream_select/dialer_select.rs | 88 ++++++++++--------- src/multistream_select/listener_select.rs | 31 ++++--- src/protocol/connection.rs | 25 +++--- .../tests/substream_validation.rs | 4 +- src/protocol/protocol_set.rs | 30 +++---- src/protocol/transport_service.rs | 43 +++++---- src/transport/quic/connection.rs | 59 +++++++------ src/transport/tcp/connection.rs | 66 ++++++++------ src/transport/webrtc/connection.rs | 46 ++++------ src/transport/websocket/connection.rs | 64 ++++++++------ src/types/protocol.rs | 10 +++ 11 files changed, 254 insertions(+), 212 deletions(-) diff --git a/src/multistream_select/dialer_select.rs b/src/multistream_select/dialer_select.rs index e95e129af..dc7a96c9c 100644 --- a/src/multistream_select/dialer_select.rs +++ b/src/multistream_select/dialer_select.rs @@ -39,6 +39,7 @@ use std::{ convert::TryFrom as _, iter, mem, pin::Pin, + sync::Arc, task::{Context, Poll}, }; @@ -320,11 +321,15 @@ enum HandshakeState { /// `multistream-select` dialer handshake state. #[derive(Debug)] pub struct WebRtcDialerState { - /// Proposed main protocol. - protocol: ProtocolName, + /// Protocols to negotiate in preference order. + /// + /// `protocols[0]` is the main protocol; `protocols[1..]` are fallbacks. The list is + /// shared via [`Arc`] so it can be precomputed once per protocol handler and reused + /// across substreams without per-substream allocation. + protocols: Arc<[ProtocolName]>, - /// Fallback names of the main protocol. - fallback_names: Vec, + /// Index into [`Self::protocols`] of the protocol currently being proposed. + current_index: usize, /// Dialer handshake state. state: HandshakeState, @@ -333,15 +338,13 @@ pub struct WebRtcDialerState { impl WebRtcDialerState { /// Propose protocol to remote peer. /// - /// `fallback_names` must be in preference order, the first element is the - /// next protocol to try. + /// `protocols[0]` is proposed first; the rest are tried in order via + /// [`Self::propose_next_fallback`] when the peer rejects the current one. /// /// Return [`WebRtcDialerState`] which is used to drive forward the negotiation and an encoded /// `multistream-select` message that contains the protocol proposal for the substream. - pub fn propose( - protocol: ProtocolName, - mut fallback_names: Vec, - ) -> crate::Result<(Self, Vec)> { + pub fn propose(protocols: Arc<[ProtocolName]>) -> crate::Result<(Self, Vec)> { + let protocol = protocols.first().ok_or(Error::InvalidData)?; let message = webrtc_encode_multistream_message( Message::Protocol( Protocol::try_from(protocol.as_ref()).map_err(|_| Error::InvalidData)?, @@ -351,33 +354,35 @@ impl WebRtcDialerState { .freeze() .to_vec(); - // Reverse fallback_names so that we can pop from it. - fallback_names.reverse(); - Ok(( Self { - protocol, - fallback_names, + protocols, + current_index: 0, state: HandshakeState::WaitingResponse, }, message, )) } + /// Currently proposed protocol. + fn current_protocol(&self) -> &ProtocolName { + &self.protocols[self.current_index] + } + /// Propose the next fallback protocol to the remote peer. /// /// Returns `None` if there are no more fallback protocols to try. /// Returns `Some(message)` with the encoded message to send, containing the protocol name. pub fn propose_next_fallback(&mut self) -> crate::Result>> { - let Some(next) = self.fallback_names.pop() else { + if self.current_index + 1 >= self.protocols.len() { return Ok(None); - }; - - self.protocol = next; + } + self.current_index += 1; let message = webrtc_encode_multistream_message( Message::Protocol( - Protocol::try_from(self.protocol.as_ref()).map_err(|_| Error::InvalidData)?, + Protocol::try_from(self.current_protocol().as_ref()) + .map_err(|_| Error::InvalidData)?, ), false, )? @@ -460,9 +465,9 @@ impl WebRtcDialerState { return Err(crate::error::NegotiationError::StateMismatch); } - if self.protocol.as_bytes() == protocol.as_ref() { + if self.current_protocol().as_bytes() == protocol.as_ref() { check_trailing_bytes(&remaining); - return Ok(HandshakeResult::Succeeded(self.protocol.clone())); + return Ok(HandshakeResult::Succeeded(self.current_protocol().clone())); } return Err(crate::error::NegotiationError::MultistreamSelectError( @@ -835,7 +840,8 @@ mod tests { #[test] fn propose() { let (mut dialer_state, message) = - WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); + WebRtcDialerState::propose(Arc::from([ProtocolName::from("/13371338/proto/1")])) + .unwrap(); let mut bytes = BytesMut::with_capacity(32); bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8); @@ -852,10 +858,10 @@ mod tests { #[test] fn propose_with_fallback() { - let (mut dialer_state, message) = WebRtcDialerState::propose( + let (mut dialer_state, message) = WebRtcDialerState::propose(Arc::from([ ProtocolName::from("/13371338/proto/1"), - vec![ProtocolName::from("/sup/proto/1")], - ) + ProtocolName::from("/sup/proto/1"), + ])) .unwrap(); // Initial message should only contain the main protocol, not the fallback. @@ -874,10 +880,10 @@ mod tests { #[test] fn propose_next_fallback() { - let (mut dialer_state, _message) = WebRtcDialerState::propose( + let (mut dialer_state, _message) = WebRtcDialerState::propose(Arc::from([ ProtocolName::from("/13371338/proto/1"), - vec![ProtocolName::from("/sup/proto/1")], - ) + ProtocolName::from("/sup/proto/1"), + ])) .unwrap(); // Simulate receiving header-only response, transitioning to WaitingProtocol. @@ -943,7 +949,8 @@ mod tests { message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap(); let (mut dialer_state, _message) = - WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); + WebRtcDialerState::propose(Arc::from([ProtocolName::from("/13371338/proto/1")])) + .unwrap(); match dialer_state.register_response(bytes.freeze().to_vec()) { Ok(HandshakeResult::NotReady) => {} @@ -966,7 +973,8 @@ mod tests { let response = bytes.freeze().to_vec(); let (mut dialer_state, _message) = - WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap(); + WebRtcDialerState::propose(Arc::from([ProtocolName::from("/13371338/proto/1")])) + .unwrap(); match dialer_state.register_response(response) { Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {} @@ -983,10 +991,10 @@ mod tests { .unwrap() .freeze(); - let (mut dialer_state, _message) = WebRtcDialerState::propose( + let (mut dialer_state, _message) = WebRtcDialerState::propose(Arc::from([ ProtocolName::from("/13371338/proto/1"), - vec![ProtocolName::from("/sup/proto/1")], - ) + ProtocolName::from("/sup/proto/1"), + ])) .unwrap(); match dialer_state.register_response(message.to_vec()) { @@ -1006,10 +1014,10 @@ mod tests { .unwrap() .freeze(); - let (mut dialer_state, _message) = WebRtcDialerState::propose( + let (mut dialer_state, _message) = WebRtcDialerState::propose(Arc::from([ ProtocolName::from("/13371338/proto/1"), - vec![ProtocolName::from("/sup/proto/1")], - ) + ProtocolName::from("/sup/proto/1"), + ])) .unwrap(); dialer_state.propose_next_fallback(); @@ -1024,10 +1032,10 @@ mod tests { #[test] fn reject_unproposed_fallback_confirmation() { - let (mut dialer_state, _message) = WebRtcDialerState::propose( + let (mut dialer_state, _message) = WebRtcDialerState::propose(Arc::from([ ProtocolName::from("/13371338/proto/1"), - vec![ProtocolName::from("/sup/proto/1")], - ) + ProtocolName::from("/sup/proto/1"), + ])) .unwrap(); // The dialer has only proposed the main protocol. The fallback is stored for a diff --git a/src/multistream_select/listener_select.rs b/src/multistream_select/listener_select.rs index 672e1edd0..4dfc0dadb 100644 --- a/src/multistream_select/listener_select.rs +++ b/src/multistream_select/listener_select.rs @@ -389,11 +389,14 @@ fn decode_multistream_message(data: &mut Bytes) -> Result, +pub fn webrtc_listener_negotiate<'a, I>( + supported_protocols: I, mut payload: Bytes, header_received: bool, -) -> crate::Result { +) -> crate::Result +where + I: IntoIterator, +{ // Save for zero-copy header echo (Bytes::clone is O(1)). let raw_payload = payload.clone(); @@ -448,7 +451,7 @@ pub fn webrtc_listener_negotiate( "listener: checking protocol", ); - for supported in supported_protocols.iter() { + for supported in supported_protocols { if protocol.as_ref() == supported.as_bytes() { return Ok(ListenerSelectResult::Accepted { protocol: supported.clone(), @@ -494,7 +497,7 @@ mod tests { .unwrap() .freeze(); - match webrtc_listener_negotiate(local_protocols, message, false) { + match webrtc_listener_negotiate(&local_protocols, message, false) { Err(error) => panic!("error received: {error:?}"), Ok(ListenerSelectResult::Rejected { .. }) => panic!("message rejected"), Ok(ListenerSelectResult::PendingProtocol { .. }) => panic!("unexpected pending"), @@ -523,7 +526,7 @@ mod tests { .unwrap() .freeze(); - match webrtc_listener_negotiate(local_protocols, message, false) { + match webrtc_listener_negotiate(&local_protocols, message, false) { Err(error) => assert!(std::matches!( error, Error::NegotiationError(error::NegotiationError::ParseError( @@ -549,7 +552,7 @@ mod tests { Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); let payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); - match webrtc_listener_negotiate(local_protocols, payload.clone(), false) { + match webrtc_listener_negotiate(&local_protocols, payload.clone(), false) { Ok(ListenerSelectResult::PendingProtocol { message }) => { assert_eq!(message, payload); } @@ -574,7 +577,7 @@ mod tests { .unwrap(); let payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); - match webrtc_listener_negotiate(local_protocols, payload, false) { + match webrtc_listener_negotiate(&local_protocols, payload, false) { Err(error) => assert!(std::matches!( error, Error::NegotiationError(error::NegotiationError::MultistreamSelectError( @@ -601,7 +604,7 @@ mod tests { .unwrap() .freeze(); - match webrtc_listener_negotiate(local_protocols, message, false) { + match webrtc_listener_negotiate(&local_protocols, message, false) { Err(error) => panic!("error received: {error:?}"), Ok(ListenerSelectResult::Rejected { message }) => { assert_eq!( @@ -625,7 +628,7 @@ mod tests { Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); let header_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); - match webrtc_listener_negotiate(local_protocols.clone(), header_payload.clone(), false) { + match webrtc_listener_negotiate(&local_protocols, header_payload.clone(), false) { Ok(ListenerSelectResult::PendingProtocol { message }) => { assert_eq!(message, header_payload); } @@ -639,7 +642,7 @@ mod tests { .unwrap(); let proto1_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); - match webrtc_listener_negotiate(local_protocols.clone(), proto1_payload, true) { + match webrtc_listener_negotiate(&local_protocols, proto1_payload, true) { Ok(ListenerSelectResult::Rejected { message }) => { assert_eq!( message, @@ -658,7 +661,7 @@ mod tests { .unwrap(); let proto2_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); - match webrtc_listener_negotiate(local_protocols, proto2_payload, true) { + match webrtc_listener_negotiate(&local_protocols, proto2_payload, true) { Ok(ListenerSelectResult::Rejected { message }) => { assert_eq!( message, @@ -680,7 +683,7 @@ mod tests { Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap(); let header_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); - match webrtc_listener_negotiate(local_protocols.clone(), header_payload.clone(), false) { + match webrtc_listener_negotiate(&local_protocols, header_payload.clone(), false) { Ok(ListenerSelectResult::PendingProtocol { message }) => { assert_eq!(message, header_payload); } @@ -694,7 +697,7 @@ mod tests { .unwrap(); let proto_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap()); - match webrtc_listener_negotiate(local_protocols, proto_payload, true) { + match webrtc_listener_negotiate(&local_protocols, proto_payload, true) { Ok(ListenerSelectResult::Accepted { protocol, .. }) => { assert_eq!(protocol, ProtocolName::from("/13371338/proto/1")); } diff --git a/src/protocol/connection.rs b/src/protocol/connection.rs index a11bee600..45872b118 100644 --- a/src/protocol/connection.rs +++ b/src/protocol/connection.rs @@ -28,6 +28,8 @@ use crate::{ use tokio::sync::mpsc::{error::TrySendError, Sender, WeakSender}; +use std::sync::Arc; + /// Connection type, from the point of view of the protocol. #[derive(Debug, Clone)] enum ConnectionType { @@ -111,12 +113,14 @@ impl ConnectionHandle { } } - /// Open substream to remote peer over `protocol` and send the acquired permit to the + /// Open substream to remote peer over `protocols` and send the acquired permit to the /// transport so it can be given to the opened substream. + /// + /// `protocols[0]` is the main protocol; the rest are fallbacks. The list is shared + /// via [`Arc`] so callers reuse the same precomputed list for every substream. pub fn open_substream( &mut self, - protocol: ProtocolName, - fallback_names: Vec, + protocols: Arc<[ProtocolName]>, substream_id: SubstreamId, permit: Permit, keep_alive: SubstreamKeepAlive, @@ -127,8 +131,7 @@ impl ConnectionHandle { inactive.upgrade().ok_or(SubstreamError::ConnectionClosed)?, } .try_send(ProtocolCommand::OpenSubstream { - protocol: protocol.clone(), - fallback_names, + protocols, substream_id, connection_id: self.connection_id, permit, @@ -214,8 +217,7 @@ mod tests { let permit = handle.try_get_permit().unwrap(); let result = handle.open_substream( - ProtocolName::from("/protocol/1"), - Vec::new(), + Arc::from([ProtocolName::from("/protocol/1")]), SubstreamId::new(), permit, SubstreamKeepAlive::Yes, @@ -234,8 +236,7 @@ mod tests { drop(_rx); let result = handle.open_substream( - ProtocolName::from("/protocol/1"), - Vec::new(), + Arc::from([ProtocolName::from("/protocol/1")]), SubstreamId::new(), permit, SubstreamKeepAlive::Yes, @@ -252,8 +253,7 @@ mod tests { let permit = handle.try_get_permit().unwrap(); let result = handle.open_substream( - ProtocolName::from("/protocol/1"), - Vec::new(), + Arc::from([ProtocolName::from("/protocol/1")]), SubstreamId::new(), permit, SubstreamKeepAlive::Yes, @@ -262,8 +262,7 @@ mod tests { let permit = handle.try_get_permit().unwrap(); match handle.open_substream( - ProtocolName::from("/protocol/1"), - Vec::new(), + Arc::from([ProtocolName::from("/protocol/1")]), SubstreamId::new(), permit, SubstreamKeepAlive::Yes, diff --git a/src/protocol/notification/tests/substream_validation.rs b/src/protocol/notification/tests/substream_validation.rs index 27e391815..e5b0f834a 100644 --- a/src/protocol/notification/tests/substream_validation.rs +++ b/src/protocol/notification/tests/substream_validation.rs @@ -137,14 +137,14 @@ async fn substream_accepted() { // protocol asks for outbound substream to be opened and its state is changed accordingly let ProtocolCommand::OpenSubstream { - protocol, + protocols, substream_id, .. } = proto_rx.recv().await.unwrap() else { panic!("invalid commnd received"); }; - assert_eq!(protocol, ProtocolName::from("/notif/1")); + assert_eq!(protocols[0], ProtocolName::from("/notif/1")); assert_eq!(substream_id, SubstreamId::from(0usize)); let expected = SubstreamId::from(0usize); diff --git a/src/protocol/protocol_set.rs b/src/protocol/protocol_set.rs index a4618807d..f23204eaf 100644 --- a/src/protocol/protocol_set.rs +++ b/src/protocol/protocol_set.rs @@ -177,19 +177,13 @@ impl From for TransportEvent { pub enum ProtocolCommand { /// Open substream. OpenSubstream { - /// Protocol name. - protocol: ProtocolName, - - /// Fallback names. + /// Protocols to negotiate. /// - /// If the protocol has changed its name but wishes to support the old name(s), it must - /// provide the old protocol names in `fallback_names`. These are fed into - /// `multistream-select` which them attempts to negotiate a protocol for the substream - /// using one of the provided names and if the substream is negotiated successfully, will - /// report back the actual protocol name that was negotiated, in case the protocol - /// needs to deal with the old version of the protocol in different way compared to - /// the new version. - fallback_names: Vec, + /// `protocols[0]` is the main protocol; the rest are fallbacks fed into + /// `multistream-select`, which attempts to negotiate one of them. The list is + /// precomputed once per protocol handler and shared via [`Arc`] so it can be + /// cheaply forwarded with every substream open. + protocols: Arc<[ProtocolName]>, /// Substream ID. /// @@ -236,7 +230,10 @@ pub struct ProtocolSet { /// Mapping `fallback_name` -> `main_name`. fallback_names: HashMap, /// Connection keep-alive settings for both main & fallback protocol names. - keep_alives: HashMap, + /// + /// Wrapped in [`Arc`] so it can be handed to the listener path on every inbound + /// substream without re-cloning the map. + keep_alives: Arc>, } impl ProtocolSet { @@ -275,7 +272,8 @@ impl ProtocolSet { ) }) .collect::>(); - let keep_alives = main_keep_alives.into_iter().chain(fallback_keep_alives).collect(); + let keep_alives = + Arc::new(main_keep_alives.into_iter().chain(fallback_keep_alives).collect()); ProtocolSet { rx, @@ -310,8 +308,8 @@ impl ProtocolSet { } /// Get the list of all supported protocols with corresponding keep-alive settings. - pub fn protocols_with_keep_alives(&self) -> HashMap { - self.keep_alives.clone() + pub fn protocols_with_keep_alives(&self) -> Arc> { + Arc::clone(&self.keep_alives) } /// Report to `protocol` that substream was opened for `peer`. diff --git a/src/protocol/transport_service.rs b/src/protocol/transport_service.rs index ec06b239a..120ded3a3 100644 --- a/src/protocol/transport_service.rs +++ b/src/protocol/transport_service.rs @@ -286,8 +286,13 @@ pub struct TransportService { /// Protocol. protocol: ProtocolName, - /// Fallback names for the protocol. - fallback_names: Vec, + /// Protocols to negotiate on every outbound substream, with the main protocol at index 0 + /// and fallbacks afterwards. + /// + /// Precomputed once in [`Self::new`] and reused for every substream open. The list is + /// shared with the transport via [`Arc`], avoiding a per-substream allocation in the hot + /// path (see paritytech/litep2p#346). + outbound_protocols: Arc<[ProtocolName]>, /// Open connections. connections: HashMap, @@ -323,12 +328,17 @@ impl TransportService { let keep_alive_tracker = KeepAliveTracker::new(keep_alive_timeout); + let outbound_protocols: Arc<[ProtocolName]> = std::iter::once(protocol.clone()) + .chain(fallback_names) + .collect::>() + .into(); + ( Self { rx, protocol, local_peer_id, - fallback_names, + outbound_protocols, transport_handle, next_substream_id, connections: HashMap::new(), @@ -592,8 +602,7 @@ impl TransportService { connection .open_substream( - self.protocol.clone(), - self.fallback_names.clone(), + Arc::clone(&self.outbound_protocols), substream_id, permit, self.substream_keep_alive, @@ -1426,12 +1435,12 @@ mod tests { let protocol_command = cmd_rx1.recv().await.unwrap(); match protocol_command { ProtocolCommand::OpenSubstream { - protocol, + protocols, substream_id: opened_substream_id, permit, .. } => { - assert_eq!(protocol, ProtocolName::from("/notif/1")); + assert_eq!(protocols[0], ProtocolName::from("/notif/1")); assert_eq!(substream_id, opened_substream_id); // Save the substream permit for later. @@ -1444,12 +1453,12 @@ mod tests { let protocol_command = cmd_rx1.recv().await.unwrap(); match protocol_command { ProtocolCommand::OpenSubstream { - protocol, + protocols, substream_id: opened_substream_id, permit, .. } => { - assert_eq!(protocol, ProtocolName::from("/notif/1")); + assert_eq!(protocols[0], ProtocolName::from("/notif/1")); assert_eq!(second_substream_id, opened_substream_id); // Save the substream permit for later. @@ -1475,12 +1484,12 @@ mod tests { let protocol_command = cmd_rx1.recv().await.unwrap(); match protocol_command { ProtocolCommand::OpenSubstream { - protocol, + protocols, substream_id: opened_substream_id, permit, .. } => { - assert_eq!(protocol, ProtocolName::from("/notif/1")); + assert_eq!(protocols[0], ProtocolName::from("/notif/1")); assert_eq!(substream_id, opened_substream_id); // Save the substream permit for later. @@ -1561,12 +1570,12 @@ mod tests { let protocol_command = cmd_rx1.recv().await.unwrap(); match protocol_command { ProtocolCommand::OpenSubstream { - protocol, + protocols, substream_id: opened_substream_id, permit, .. } => { - assert_eq!(protocol, ProtocolName::from("/notif/1")); + assert_eq!(protocols[0], ProtocolName::from("/notif/1")); assert_eq!(substream_id, opened_substream_id); // Save the substream permit for later. @@ -1579,12 +1588,12 @@ mod tests { let protocol_command = cmd_rx1.recv().await.unwrap(); match protocol_command { ProtocolCommand::OpenSubstream { - protocol, + protocols, substream_id: opened_substream_id, permit, .. } => { - assert_eq!(protocol, ProtocolName::from("/notif/1")); + assert_eq!(protocols[0], ProtocolName::from("/notif/1")); assert_eq!(second_substream_id, opened_substream_id); // Save the substream permit for later. @@ -1619,12 +1628,12 @@ mod tests { let protocol_command = cmd_rx1.recv().await.unwrap(); match protocol_command { ProtocolCommand::OpenSubstream { - protocol, + protocols, substream_id: opened_substream_id, permit, .. } => { - assert_eq!(protocol, ProtocolName::from("/notif/1")); + assert_eq!(protocols[0], ProtocolName::from("/notif/1")); assert_eq!(substream_id, opened_substream_id); // Save the substream permit for later. diff --git a/src/transport/quic/connection.rs b/src/transport/quic/connection.rs index 2d91cac30..34ac8a8e1 100644 --- a/src/transport/quic/connection.rs +++ b/src/transport/quic/connection.rs @@ -20,7 +20,7 @@ //! QUIC connection. -use std::{collections::HashMap, time::Duration}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use crate::{ config::Role, @@ -32,7 +32,10 @@ use crate::{ quic::substream::{NegotiatingSubstream, Substream}, Endpoint, }, - types::{protocol::ProtocolName, SubstreamId}, + types::{ + protocol::{protocol_name_as_str, ProtocolName}, + SubstreamId, + }, BandwidthSink, PeerId, }; @@ -137,12 +140,17 @@ impl QuicConnection { } /// Negotiate protocol. - async fn negotiate_protocol( + async fn negotiate_protocol( stream: S, role: &Role, - protocols: Vec<&str>, - ) -> Result<(Negotiated, ProtocolName), NegotiationError> { - tracing::trace!(target: LOG_TARGET, ?protocols, "negotiating protocols"); + protocols: I, + ) -> Result<(Negotiated, ProtocolName), NegotiationError> + where + S: AsyncRead + AsyncWrite + Unpin, + I: IntoIterator, + I::Item: AsRef<[u8]> + Clone + std::fmt::Display, + { + tracing::trace!(target: LOG_TARGET, "negotiating protocols"); let (protocol, socket) = match role { Role::Dialer => dialer_select_proto(stream, protocols, Version::V1).await, @@ -150,34 +158,33 @@ impl QuicConnection { } .map_err(NegotiationError::MultistreamSelectError)?; - tracing::trace!(target: LOG_TARGET, ?protocol, "protocol negotiated"); + tracing::trace!(target: LOG_TARGET, %protocol, "protocol negotiated"); Ok((socket, ProtocolName::from(protocol.to_string()))) } - /// Open substream for `protocol`. + /// Open substream for `protocols`. async fn open_substream( handle: QuinnConnection, permit: Permit, substream_id: SubstreamId, - protocol: ProtocolName, - fallback_names: Vec, + protocols: Arc<[ProtocolName]>, keep_alive: SubstreamKeepAlive, ) -> Result { - tracing::debug!(target: LOG_TARGET, ?protocol, ?substream_id, "open substream"); + tracing::debug!( + target: LOG_TARGET, + protocol = %protocols[0], + ?substream_id, + "open substream", + ); let stream = match handle.open_bi().await { Ok((send_stream, recv_stream)) => NegotiatingSubstream::new(send_stream, recv_stream), Err(error) => return Err(NegotiationError::Quic(error.into()).into()), }; - // TODO: https://github.com/paritytech/litep2p/issues/346 protocols don't change after - // they've been initialized so this should be done only once - let protocols = std::iter::once(&*protocol) - .chain(fallback_names.iter().map(|protocol| &**protocol)) - .collect(); - - let (io, protocol) = Self::negotiate_protocol(stream, &Role::Dialer, protocols).await?; + let (io, protocol) = + Self::negotiate_protocol(stream, &Role::Dialer, protocols.iter().map(protocol_name_as_str)).await?; tracing::trace!( target: LOG_TARGET, @@ -203,7 +210,7 @@ impl QuicConnection { /// Accept bidirectional substream from rmeote peer. async fn accept_substream( stream: NegotiatingSubstream, - protocols: HashMap, + protocols: Arc>, substream_id: SubstreamId, permit: Permit, ) -> Result { @@ -213,9 +220,9 @@ impl QuicConnection { "accept inbound substream" ); - let protocol_names = protocols.keys().map(|protocol| &**protocol).collect::>(); let (io, protocol) = - Self::negotiate_protocol(stream, &Role::Listener, protocol_names).await?; + Self::negotiate_protocol(stream, &Role::Listener, protocols.keys().map(protocol_name_as_str)) + .await?; let keep_alive = *protocols.get(&protocol).expect("protocol to be one of the keys"); tracing::trace!( @@ -347,8 +354,7 @@ impl QuicConnection { ).await; } Some(ProtocolCommand::OpenSubstream { - protocol, - fallback_names, + protocols, substream_id, permit, keep_alive, @@ -356,11 +362,11 @@ impl QuicConnection { }) => { let connection = self.connection.clone(); let substream_open_timeout = self.substream_open_timeout; + let protocol = protocols[0].clone(); tracing::trace!( target: LOG_TARGET, - ?protocol, - ?fallback_names, + %protocol, ?substream_id, "open substream" ); @@ -372,8 +378,7 @@ impl QuicConnection { connection, permit, substream_id, - protocol.clone(), - fallback_names, + protocols, keep_alive, ), ) diff --git a/src/transport/tcp/connection.rs b/src/transport/tcp/connection.rs index 7f296952d..546ed5b1d 100644 --- a/src/transport/tcp/connection.rs +++ b/src/transport/tcp/connection.rs @@ -33,7 +33,10 @@ use crate::{ tcp::substream::Substream, Endpoint, }, - types::{protocol::ProtocolName, ConnectionId, SubstreamId}, + types::{ + protocol::{protocol_name_as_str, ProtocolName}, + ConnectionId, SubstreamId, + }, BandwidthSink, PeerId, }; @@ -275,11 +278,15 @@ impl TcpConnection { substream_id: SubstreamId, permit: Permit, keep_alive: SubstreamKeepAlive, - protocol: ProtocolName, - fallback_names: Vec, + protocols: Arc<[ProtocolName]>, open_timeout: Duration, ) -> Result { - tracing::debug!(target: LOG_TARGET, ?protocol, ?substream_id, "open substream"); + tracing::debug!( + target: LOG_TARGET, + protocol = %protocols[0], + ?substream_id, + "open substream", + ); let stream = match control.open_stream().await { Ok(stream) => { @@ -300,14 +307,13 @@ impl TcpConnection { } }; - // TODO: https://github.com/paritytech/litep2p/issues/346 protocols don't change after - // they've been initialized so this should be done only once - let protocols = std::iter::once(&*protocol) - .chain(fallback_names.iter().map(|protocol| &**protocol)) - .collect(); - - let (io, protocol) = - Self::negotiate_protocol(stream, &Role::Dialer, protocols, open_timeout).await?; + let (io, protocol) = Self::negotiate_protocol( + stream, + &Role::Dialer, + protocols.iter().map(protocol_name_as_str), + open_timeout, + ) + .await?; Ok(NegotiatedSubstream { io: io.inner(), @@ -360,7 +366,7 @@ impl TcpConnection { stream: crate::yamux::Stream, permit: Permit, substream_id: SubstreamId, - protocols: HashMap, + protocols: Arc>, open_timeout: Duration, ) -> Result { tracing::trace!( @@ -369,9 +375,13 @@ impl TcpConnection { "accept inbound substream", ); - let protocol_names = protocols.keys().map(|protocol| &**protocol).collect::>(); - let (io, protocol) = - Self::negotiate_protocol(stream, &Role::Listener, protocol_names, open_timeout).await?; + let (io, protocol) = Self::negotiate_protocol( + stream, + &Role::Listener, + protocols.keys().map(protocol_name_as_str), + open_timeout, + ) + .await?; let keep_alive = *protocols.get(&protocol).expect("protocol to be one of the keys"); tracing::trace!( @@ -391,13 +401,18 @@ impl TcpConnection { } /// Negotiate protocol. - async fn negotiate_protocol( + async fn negotiate_protocol( stream: S, role: &Role, - protocols: Vec<&str>, + protocols: I, substream_open_timeout: Duration, - ) -> Result<(Negotiated, ProtocolName), NegotiationError> { - tracing::trace!(target: LOG_TARGET, ?protocols, "negotiating protocols"); + ) -> Result<(Negotiated, ProtocolName), NegotiationError> + where + S: AsyncRead + AsyncWrite + Unpin, + I: IntoIterator, + I::Item: AsRef<[u8]> + Clone + std::fmt::Display, + { + tracing::trace!(target: LOG_TARGET, "negotiating protocols"); match tokio::time::timeout(substream_open_timeout, async move { match role { @@ -410,7 +425,7 @@ impl TcpConnection { Err(_) => Err(NegotiationError::Timeout), Ok(Err(error)) => Err(NegotiationError::MultistreamSelectError(error)), Ok(Ok((protocol, socket))) => { - tracing::trace!(target: LOG_TARGET, ?protocol, "protocol negotiated"); + tracing::trace!(target: LOG_TARGET, %protocol, "protocol negotiated"); Ok((socket, ProtocolName::from(protocol.to_string()))) } @@ -681,8 +696,7 @@ impl TcpConnection { ) -> crate::Result { match command { Some(ProtocolCommand::OpenSubstream { - protocol, - fallback_names, + protocols, substream_id, connection_id, permit, @@ -690,10 +704,11 @@ impl TcpConnection { }) => { let control = self.control.clone(); let open_timeout = self.substream_open_timeout; + let protocol = protocols[0].clone(); tracing::trace!( target: LOG_TARGET, - ?protocol, + %protocol, ?substream_id, ?connection_id, "open substream", @@ -707,8 +722,7 @@ impl TcpConnection { substream_id, permit, keep_alive, - protocol.clone(), - fallback_names, + protocols, open_timeout, ), ) diff --git a/src/transport/webrtc/connection.rs b/src/transport/webrtc/connection.rs index d7e31a074..f8bf6b7b0 100644 --- a/src/transport/webrtc/connection.rs +++ b/src/transport/webrtc/connection.rs @@ -68,11 +68,12 @@ const MAX_PENDING_PER_CHANNEL: usize = 16; /// Opening channel context. #[derive(Debug)] struct ChannelContext { - /// Protocol name. - protocol: ProtocolName, - - /// Fallback names. - fallback_names: Vec, + /// Protocols to negotiate. + /// + /// `protocols[0]` is the main protocol; `protocols[1..]` are fallbacks. Shared via + /// [`Arc`] with the originating protocol handler so opening a substream does not + /// require duplicating the list. + protocols: Arc<[ProtocolName]>, /// Substream ID. substream_id: SubstreamId, @@ -305,7 +306,7 @@ impl WebRtcConnection { channel.set_buffered_amount_low_threshold(BACKPRESSURE_THRESHOLD); } - let Some(mut context) = self.pending_outbound.remove(&channel_id) else { + let Some(context) = self.pending_outbound.remove(&channel_id) else { tracing::trace!( target: LOG_TARGET, peer = ?self.peer, @@ -322,9 +323,7 @@ impl WebRtcConnection { return Ok(()); }; - let fallback_names = std::mem::take(&mut context.fallback_names); - let (dialer_state, message) = - WebRtcDialerState::propose(context.protocol.clone(), fallback_names)?; + let (dialer_state, message) = WebRtcDialerState::propose(Arc::clone(&context.protocols))?; let message = WebRtcMessage::encode(message, None); self.write(channel_id, message)?; @@ -488,9 +487,8 @@ impl WebRtcConnection { let payload = WebRtcMessage::decode(&data)?.payload.ok_or(Error::InvalidData)?; let protocols = self.protocol_set.protocols_with_keep_alives(); - let protocol_names = protocols.keys().cloned().collect(); let (response, negotiated) = - match webrtc_listener_negotiate(protocol_names, payload.into(), header_received)? { + match webrtc_listener_negotiate(protocols.keys(), payload.into(), header_received)? { ListenerSelectResult::Accepted { protocol, message } => (message, Some(protocol)), ListenerSelectResult::Rejected { message } | ListenerSelectResult::PendingProtocol { message } => (message, None), @@ -776,7 +774,7 @@ impl WebRtcConnection { context, dialer_state, } => { - let protocol = context.protocol.clone(); + let protocol = context.protocols[0].clone(); let substream_id = context.substream_id; let lifetime_permit = context.keep_alive.then(|| context.opening_permit.clone()); @@ -880,8 +878,7 @@ impl WebRtcConnection { /// Open outbound substream. fn on_open_substream( &mut self, - protocol: ProtocolName, - fallback_names: Vec, + protocols: Arc<[ProtocolName]>, substream_id: SubstreamId, opening_permit: Permit, keep_alive: SubstreamKeepAlive, @@ -891,7 +888,7 @@ impl WebRtcConnection { ordered: false, reliability: Default::default(), negotiated: None, - protocol: protocol.to_string(), + protocol: protocols[0].to_string(), }); tracing::trace!( @@ -899,16 +896,14 @@ impl WebRtcConnection { peer = ?self.peer, ?channel_id, ?substream_id, - ?protocol, - ?fallback_names, + protocol = %protocols[0], "open data channel", ); self.pending_outbound.insert( channel_id, ChannelContext { - protocol, - fallback_names, + protocols, substream_id, opening_permit, keep_alive, @@ -1100,8 +1095,7 @@ impl WebRtcConnection { return self.on_connection_closed().await; } Some(ProtocolCommand::OpenSubstream { - protocol, - fallback_names, + protocols, substream_id, permit, keep_alive, @@ -1114,20 +1108,14 @@ impl WebRtcConnection { tracing::debug!( target: LOG_TARGET, peer = ?self.peer, - ?protocol, + protocol = %protocols[0], is_alive = self.rtc.is_alive(), is_connected = self.rtc.is_connected(), "rejecting substream open: connection not healthy", ); continue; } - self.on_open_substream( - protocol, - fallback_names, - substream_id, - permit, - keep_alive, - ); + self.on_open_substream(protocols, substream_id, permit, keep_alive); } }, _ = tokio::time::sleep(duration) => { diff --git a/src/transport/websocket/connection.rs b/src/transport/websocket/connection.rs index 7c1055390..3513c4883 100644 --- a/src/transport/websocket/connection.rs +++ b/src/transport/websocket/connection.rs @@ -32,7 +32,10 @@ use crate::{ websocket::{stream::BufferedStream, substream::Substream}, Endpoint, }, - types::{protocol::ProtocolName, ConnectionId, SubstreamId}, + types::{ + protocol::{protocol_name_as_str, ProtocolName}, + ConnectionId, SubstreamId, + }, BandwidthSink, PeerId, }; @@ -43,7 +46,7 @@ use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; use tokio_util::compat::FuturesAsyncReadCompatExt; use url::Url; -use std::{collections::HashMap, time::Duration}; +use std::{collections::HashMap, sync::Arc, time::Duration}; mod schema { pub(super) mod noise { @@ -203,13 +206,18 @@ impl WebSocketConnection { } /// Negotiate protocol. - async fn negotiate_protocol( + async fn negotiate_protocol( stream: S, role: &Role, - protocols: Vec<&str>, + protocols: I, substream_open_timeout: Duration, - ) -> Result<(Negotiated, ProtocolName), NegotiationError> { - tracing::trace!(target: LOG_TARGET, ?protocols, "negotiating protocols"); + ) -> Result<(Negotiated, ProtocolName), NegotiationError> + where + S: AsyncRead + AsyncWrite + Unpin, + I: IntoIterator, + I::Item: AsRef<[u8]> + Clone + std::fmt::Display, + { + tracing::trace!(target: LOG_TARGET, "negotiating protocols"); match tokio::time::timeout(substream_open_timeout, async move { match role { @@ -222,7 +230,7 @@ impl WebSocketConnection { Err(_) => Err(NegotiationError::Timeout), Ok(Err(error)) => Err(NegotiationError::MultistreamSelectError(error)), Ok(Ok((protocol, socket))) => { - tracing::trace!(target: LOG_TARGET, ?protocol, "protocol negotiated"); + tracing::trace!(target: LOG_TARGET, %protocol, "protocol negotiated"); Ok((socket, ProtocolName::from(protocol.to_string()))) } @@ -378,7 +386,7 @@ impl WebSocketConnection { stream: crate::yamux::Stream, permit: Permit, substream_id: SubstreamId, - protocols: HashMap, + protocols: Arc>, substream_open_timeout: Duration, ) -> Result { tracing::trace!( @@ -387,11 +395,10 @@ impl WebSocketConnection { "accept inbound substream" ); - let protocol_names = protocols.keys().map(|protocol| &**protocol).collect::>(); let (io, protocol) = Self::negotiate_protocol( stream, &Role::Listener, - protocol_names, + protocols.keys().map(protocol_name_as_str), substream_open_timeout, ) .await?; @@ -413,17 +420,21 @@ impl WebSocketConnection { }) } - /// Open substream for `protocol`. + /// Open substream for `protocols`. pub async fn open_substream( mut control: crate::yamux::Control, permit: Permit, substream_id: SubstreamId, - protocol: ProtocolName, - fallback_names: Vec, + protocols: Arc<[ProtocolName]>, substream_open_timeout: Duration, keep_alive: SubstreamKeepAlive, ) -> Result { - tracing::debug!(target: LOG_TARGET, ?protocol, ?substream_id, "open substream"); + tracing::debug!( + target: LOG_TARGET, + protocol = %protocols[0], + ?substream_id, + "open substream", + ); let stream = match control.open_stream().await { Ok(stream) => { @@ -444,15 +455,13 @@ impl WebSocketConnection { } }; - // TODO: https://github.com/paritytech/litep2p/issues/346 protocols don't change after - // they've been initialized so this should be done only once - let protocols = std::iter::once(&*protocol) - .chain(fallback_names.iter().map(|protocol| &**protocol)) - .collect(); - - let (io, protocol) = - Self::negotiate_protocol(stream, &Role::Dialer, protocols, substream_open_timeout) - .await?; + let (io, protocol) = Self::negotiate_protocol( + stream, + &Role::Dialer, + protocols.iter().map(protocol_name_as_str), + substream_open_timeout, + ) + .await?; Ok(NegotiatedSubstream { io: io.inner(), @@ -567,8 +576,7 @@ impl WebSocketConnection { } protocol = self.protocol_set.next() => match protocol { Some(ProtocolCommand::OpenSubstream { - protocol, - fallback_names, + protocols, substream_id, permit, keep_alive, @@ -576,10 +584,11 @@ impl WebSocketConnection { }) => { let control = self.control.clone(); let substream_open_timeout = self.substream_open_timeout; + let protocol = protocols[0].clone(); tracing::trace!( target: LOG_TARGET, - ?protocol, + %protocol, ?substream_id, "open substream" ); @@ -591,8 +600,7 @@ impl WebSocketConnection { control, permit, substream_id, - protocol.clone(), - fallback_names, + protocols, substream_open_timeout, keep_alive, ), diff --git a/src/types/protocol.rs b/src/types/protocol.rs index eb64238b5..606df3221 100644 --- a/src/types/protocol.rs +++ b/src/types/protocol.rs @@ -82,6 +82,16 @@ impl std::ops::Deref for ProtocolName { } } +/// Helper that maps `&ProtocolName -> &str` via deref. +/// +/// Defined as a free `fn` (not a closure) so it has a fully `for<'a>`-polymorphic signature. +/// This is required when the resulting `Map` iterator is captured by an `async move` block +/// that is later boxed (`Box::pin`) — closures with locally-inferred lifetimes fail HRTB and +/// produce "implementation of `FnOnce` is not general enough" errors at the boxing site. +pub(crate) fn protocol_name_as_str(p: &ProtocolName) -> &str { + p +} + impl Hash for ProtocolName { fn hash(&self, state: &mut H) { (self as &str).hash(state)