Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 48 additions & 40 deletions src/multistream_select/dialer_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use std::{
convert::TryFrom as _,
iter, mem,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

Expand Down Expand Up @@ -320,11 +321,15 @@ enum HandshakeState {
/// `multistream-select` dialer handshake state.
#[derive(Debug)]
pub struct WebRtcDialerState {
/// Proposed main protocol.
protocol: ProtocolName,
/// Protocols to negotiate in preference order.
///
/// `protocols[0]` is the main protocol; `protocols[1..]` are fallbacks. The list is
/// shared via [`Arc`] so it can be precomputed once per protocol handler and reused
/// across substreams without per-substream allocation.
protocols: Arc<[ProtocolName]>,

/// Fallback names of the main protocol.
fallback_names: Vec<ProtocolName>,
/// Index into [`Self::protocols`] of the protocol currently being proposed.
current_index: usize,

/// Dialer handshake state.
state: HandshakeState,
Expand All @@ -333,15 +338,13 @@ pub struct WebRtcDialerState {
impl WebRtcDialerState {
/// Propose protocol to remote peer.
///
/// `fallback_names` must be in preference order, the first element is the
/// next protocol to try.
/// `protocols[0]` is proposed first; the rest are tried in order via
/// [`Self::propose_next_fallback`] when the peer rejects the current one.
///
/// Return [`WebRtcDialerState`] which is used to drive forward the negotiation and an encoded
/// `multistream-select` message that contains the protocol proposal for the substream.
pub fn propose(
protocol: ProtocolName,
mut fallback_names: Vec<ProtocolName>,
) -> crate::Result<(Self, Vec<u8>)> {
pub fn propose(protocols: Arc<[ProtocolName]>) -> crate::Result<(Self, Vec<u8>)> {
let protocol = protocols.first().ok_or(Error::InvalidData)?;
let message = webrtc_encode_multistream_message(
Message::Protocol(
Protocol::try_from(protocol.as_ref()).map_err(|_| Error::InvalidData)?,
Expand All @@ -351,33 +354,35 @@ impl WebRtcDialerState {
.freeze()
.to_vec();

// Reverse fallback_names so that we can pop from it.
fallback_names.reverse();

Ok((
Self {
protocol,
fallback_names,
protocols,
current_index: 0,
state: HandshakeState::WaitingResponse,
},
message,
))
}

/// Currently proposed protocol.
fn current_protocol(&self) -> &ProtocolName {
&self.protocols[self.current_index]
}

/// Propose the next fallback protocol to the remote peer.
///
/// Returns `None` if there are no more fallback protocols to try.
/// Returns `Some(message)` with the encoded message to send, containing the protocol name.
pub fn propose_next_fallback(&mut self) -> crate::Result<Option<Vec<u8>>> {
let Some(next) = self.fallback_names.pop() else {
if self.current_index + 1 >= self.protocols.len() {
return Ok(None);
};

self.protocol = next;
}
self.current_index += 1;

let message = webrtc_encode_multistream_message(
Message::Protocol(
Protocol::try_from(self.protocol.as_ref()).map_err(|_| Error::InvalidData)?,
Protocol::try_from(self.current_protocol().as_ref())
.map_err(|_| Error::InvalidData)?,
),
false,
)?
Expand Down Expand Up @@ -460,9 +465,9 @@ impl WebRtcDialerState {
return Err(crate::error::NegotiationError::StateMismatch);
}

if self.protocol.as_bytes() == protocol.as_ref() {
if self.current_protocol().as_bytes() == protocol.as_ref() {
check_trailing_bytes(&remaining);
return Ok(HandshakeResult::Succeeded(self.protocol.clone()));
return Ok(HandshakeResult::Succeeded(self.current_protocol().clone()));
}

return Err(crate::error::NegotiationError::MultistreamSelectError(
Expand Down Expand Up @@ -835,7 +840,8 @@ mod tests {
#[test]
fn propose() {
let (mut dialer_state, message) =
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
WebRtcDialerState::propose(Arc::from([ProtocolName::from("/13371338/proto/1")]))
.unwrap();

let mut bytes = BytesMut::with_capacity(32);
bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8);
Expand All @@ -852,10 +858,10 @@ mod tests {

#[test]
fn propose_with_fallback() {
let (mut dialer_state, message) = WebRtcDialerState::propose(
let (mut dialer_state, message) = WebRtcDialerState::propose(Arc::from([
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
ProtocolName::from("/sup/proto/1"),
]))
.unwrap();

// Initial message should only contain the main protocol, not the fallback.
Expand All @@ -874,10 +880,10 @@ mod tests {

#[test]
fn propose_next_fallback() {
let (mut dialer_state, _message) = WebRtcDialerState::propose(
let (mut dialer_state, _message) = WebRtcDialerState::propose(Arc::from([
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
ProtocolName::from("/sup/proto/1"),
]))
.unwrap();

// Simulate receiving header-only response, transitioning to WaitingProtocol.
Expand Down Expand Up @@ -943,7 +949,8 @@ mod tests {
message.encode(&mut bytes).map_err(|_| Error::InvalidData).unwrap();

let (mut dialer_state, _message) =
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
WebRtcDialerState::propose(Arc::from([ProtocolName::from("/13371338/proto/1")]))
.unwrap();

match dialer_state.register_response(bytes.freeze().to_vec()) {
Ok(HandshakeResult::NotReady) => {}
Expand All @@ -966,7 +973,8 @@ mod tests {
let response = bytes.freeze().to_vec();

let (mut dialer_state, _message) =
WebRtcDialerState::propose(ProtocolName::from("/13371338/proto/1"), vec![]).unwrap();
WebRtcDialerState::propose(Arc::from([ProtocolName::from("/13371338/proto/1")]))
.unwrap();

match dialer_state.register_response(response) {
Err(error::NegotiationError::MultistreamSelectError(NegotiationError::Failed)) => {}
Expand All @@ -983,10 +991,10 @@ mod tests {
.unwrap()
.freeze();

let (mut dialer_state, _message) = WebRtcDialerState::propose(
let (mut dialer_state, _message) = WebRtcDialerState::propose(Arc::from([
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
ProtocolName::from("/sup/proto/1"),
]))
.unwrap();

match dialer_state.register_response(message.to_vec()) {
Expand All @@ -1006,10 +1014,10 @@ mod tests {
.unwrap()
.freeze();

let (mut dialer_state, _message) = WebRtcDialerState::propose(
let (mut dialer_state, _message) = WebRtcDialerState::propose(Arc::from([
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
ProtocolName::from("/sup/proto/1"),
]))
.unwrap();

dialer_state.propose_next_fallback();
Expand All @@ -1024,10 +1032,10 @@ mod tests {

#[test]
fn reject_unproposed_fallback_confirmation() {
let (mut dialer_state, _message) = WebRtcDialerState::propose(
let (mut dialer_state, _message) = WebRtcDialerState::propose(Arc::from([
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
ProtocolName::from("/sup/proto/1"),
]))
.unwrap();

// The dialer has only proposed the main protocol. The fallback is stored for a
Expand Down
31 changes: 17 additions & 14 deletions src/multistream_select/listener_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,11 +389,14 @@ fn decode_multistream_message(data: &mut Bytes) -> Result<Message, error::Negoti
/// Parse the protocol offered by the remote peer and check if it matches any locally available
/// protocol. The `header_received` parameter indicates whether the multistream-select header
/// has already been exchanged in a previous round.
pub fn webrtc_listener_negotiate(
supported_protocols: Vec<ProtocolName>,
pub fn webrtc_listener_negotiate<'a, I>(
supported_protocols: I,
mut payload: Bytes,
header_received: bool,
) -> crate::Result<ListenerSelectResult> {
) -> crate::Result<ListenerSelectResult>
where
I: IntoIterator<Item = &'a ProtocolName>,
{
// Save for zero-copy header echo (Bytes::clone is O(1)).
let raw_payload = payload.clone();

Expand Down Expand Up @@ -448,7 +451,7 @@ pub fn webrtc_listener_negotiate(
"listener: checking protocol",
);

for supported in supported_protocols.iter() {
for supported in supported_protocols {
if protocol.as_ref() == supported.as_bytes() {
return Ok(ListenerSelectResult::Accepted {
protocol: supported.clone(),
Expand Down Expand Up @@ -494,7 +497,7 @@ mod tests {
.unwrap()
.freeze();

match webrtc_listener_negotiate(local_protocols, message, false) {
match webrtc_listener_negotiate(&local_protocols, message, false) {
Err(error) => panic!("error received: {error:?}"),
Ok(ListenerSelectResult::Rejected { .. }) => panic!("message rejected"),
Ok(ListenerSelectResult::PendingProtocol { .. }) => panic!("unexpected pending"),
Expand Down Expand Up @@ -523,7 +526,7 @@ mod tests {
.unwrap()
.freeze();

match webrtc_listener_negotiate(local_protocols, message, false) {
match webrtc_listener_negotiate(&local_protocols, message, false) {
Err(error) => assert!(std::matches!(
error,
Error::NegotiationError(error::NegotiationError::ParseError(
Expand All @@ -549,7 +552,7 @@ mod tests {
Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap();
let payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap());

match webrtc_listener_negotiate(local_protocols, payload.clone(), false) {
match webrtc_listener_negotiate(&local_protocols, payload.clone(), false) {
Ok(ListenerSelectResult::PendingProtocol { message }) => {
assert_eq!(message, payload);
}
Expand All @@ -574,7 +577,7 @@ mod tests {
.unwrap();
let payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap());

match webrtc_listener_negotiate(local_protocols, payload, false) {
match webrtc_listener_negotiate(&local_protocols, payload, false) {
Err(error) => assert!(std::matches!(
error,
Error::NegotiationError(error::NegotiationError::MultistreamSelectError(
Expand All @@ -601,7 +604,7 @@ mod tests {
.unwrap()
.freeze();

match webrtc_listener_negotiate(local_protocols, message, false) {
match webrtc_listener_negotiate(&local_protocols, message, false) {
Err(error) => panic!("error received: {error:?}"),
Ok(ListenerSelectResult::Rejected { message }) => {
assert_eq!(
Expand All @@ -625,7 +628,7 @@ mod tests {
Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap();
let header_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap());

match webrtc_listener_negotiate(local_protocols.clone(), header_payload.clone(), false) {
match webrtc_listener_negotiate(&local_protocols, header_payload.clone(), false) {
Ok(ListenerSelectResult::PendingProtocol { message }) => {
assert_eq!(message, header_payload);
}
Expand All @@ -639,7 +642,7 @@ mod tests {
.unwrap();
let proto1_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap());

match webrtc_listener_negotiate(local_protocols.clone(), proto1_payload, true) {
match webrtc_listener_negotiate(&local_protocols, proto1_payload, true) {
Ok(ListenerSelectResult::Rejected { message }) => {
assert_eq!(
message,
Expand All @@ -658,7 +661,7 @@ mod tests {
.unwrap();
let proto2_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap());

match webrtc_listener_negotiate(local_protocols, proto2_payload, true) {
match webrtc_listener_negotiate(&local_protocols, proto2_payload, true) {
Ok(ListenerSelectResult::Rejected { message }) => {
assert_eq!(
message,
Expand All @@ -680,7 +683,7 @@ mod tests {
Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap();
let header_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap());

match webrtc_listener_negotiate(local_protocols.clone(), header_payload.clone(), false) {
match webrtc_listener_negotiate(&local_protocols, header_payload.clone(), false) {
Ok(ListenerSelectResult::PendingProtocol { message }) => {
assert_eq!(message, header_payload);
}
Expand All @@ -694,7 +697,7 @@ mod tests {
.unwrap();
let proto_payload = Bytes::from(UnsignedVarint::encode(bytes).unwrap());

match webrtc_listener_negotiate(local_protocols, proto_payload, true) {
match webrtc_listener_negotiate(&local_protocols, proto_payload, true) {
Ok(ListenerSelectResult::Accepted { protocol, .. }) => {
assert_eq!(protocol, ProtocolName::from("/13371338/proto/1"));
}
Expand Down
25 changes: 12 additions & 13 deletions src/protocol/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ use crate::{

use tokio::sync::mpsc::{error::TrySendError, Sender, WeakSender};

use std::sync::Arc;

/// Connection type, from the point of view of the protocol.
#[derive(Debug, Clone)]
enum ConnectionType {
Expand Down Expand Up @@ -111,12 +113,14 @@ impl ConnectionHandle {
}
}

/// Open substream to remote peer over `protocol` and send the acquired permit to the
/// Open substream to remote peer over `protocols` and send the acquired permit to the
/// transport so it can be given to the opened substream.
///
/// `protocols[0]` is the main protocol; the rest are fallbacks. The list is shared
/// via [`Arc`] so callers reuse the same precomputed list for every substream.
pub fn open_substream(
&mut self,
protocol: ProtocolName,
fallback_names: Vec<ProtocolName>,
protocols: Arc<[ProtocolName]>,
substream_id: SubstreamId,
permit: Permit,
keep_alive: SubstreamKeepAlive,
Expand All @@ -127,8 +131,7 @@ impl ConnectionHandle {
inactive.upgrade().ok_or(SubstreamError::ConnectionClosed)?,
}
.try_send(ProtocolCommand::OpenSubstream {
protocol: protocol.clone(),
fallback_names,
protocols,
substream_id,
connection_id: self.connection_id,
permit,
Expand Down Expand Up @@ -214,8 +217,7 @@ mod tests {
let permit = handle.try_get_permit().unwrap();

let result = handle.open_substream(
ProtocolName::from("/protocol/1"),
Vec::new(),
Arc::from([ProtocolName::from("/protocol/1")]),
SubstreamId::new(),
permit,
SubstreamKeepAlive::Yes,
Expand All @@ -234,8 +236,7 @@ mod tests {
drop(_rx);

let result = handle.open_substream(
ProtocolName::from("/protocol/1"),
Vec::new(),
Arc::from([ProtocolName::from("/protocol/1")]),
SubstreamId::new(),
permit,
SubstreamKeepAlive::Yes,
Expand All @@ -252,8 +253,7 @@ mod tests {
let permit = handle.try_get_permit().unwrap();

let result = handle.open_substream(
ProtocolName::from("/protocol/1"),
Vec::new(),
Arc::from([ProtocolName::from("/protocol/1")]),
SubstreamId::new(),
permit,
SubstreamKeepAlive::Yes,
Expand All @@ -262,8 +262,7 @@ mod tests {

let permit = handle.try_get_permit().unwrap();
match handle.open_substream(
ProtocolName::from("/protocol/1"),
Vec::new(),
Arc::from([ProtocolName::from("/protocol/1")]),
SubstreamId::new(),
permit,
SubstreamKeepAlive::Yes,
Expand Down
Loading