Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions src/asgi/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ pub(crate) struct ASGIWebsocketProtocol {
init_tx: Arc<atomic::AtomicBool>,
init_event: Arc<Notify>,
closed: Arc<atomic::AtomicBool>,
teardown: Arc<Notify>,
}

impl ASGIWebsocketProtocol {
Expand All @@ -378,6 +379,7 @@ impl ASGIWebsocketProtocol {
init_tx: Arc::new(false.into()),
init_event: Arc::new(Notify::new()),
closed: Arc::new(false.into()),
teardown: Arc::new(Notify::new()),
}
}

Expand Down Expand Up @@ -464,17 +466,22 @@ impl ASGIWebsocketProtocol {
fn send_message<'p>(&self, py: Python<'p>, data: Message) -> PyResult<Bound<'p, PyAny>> {
let transport = self.ws_tx.clone();
let closed = self.closed.clone();
let teardown = self.teardown.clone();

future_into_py_futlike(self.rt.clone(), py, async move {
if let Some(ws) = &mut *(transport.lock().await) {
match ws.send(data).await {
Ok(()) => return FutureResultToPy::None,
_ => {
if closed.load(atomic::Ordering::Acquire) {
log::info!("Attempted to write to a closed websocket");
return FutureResultToPy::None;
tokio::select! {
biased;
res = ws.send(data) => match res {
Ok(()) => return FutureResultToPy::None,
_ => {
if closed.load(atomic::Ordering::Acquire) {
log::info!("Attempted to write to a closed websocket");
return FutureResultToPy::None;
}
}
}
},
() = teardown.notified() => return FutureResultToPy::None,
}
}
FutureResultToPy::Err(error_flow!("Transport not initialized or closed"))
Expand Down Expand Up @@ -510,11 +517,13 @@ impl ASGIWebsocketProtocol {
Option<oneshot::Sender<WebsocketDetachedTransport>>,
WebsocketDetachedTransport,
) {
let mut ws_rx = self.ws_rx.blocking_lock();
self.closed.store(true, atomic::Ordering::Release);
self.teardown.notify_one();
let mut ws_tx = self.ws_tx.blocking_lock();
let ws_rx = self.ws_rx.try_lock().map_or(None, |mut guard| guard.take());
(
self.tx.lock().unwrap().take(),
WebsocketDetachedTransport::new(self.consumed(), ws_rx.take(), ws_tx.take(), None),
WebsocketDetachedTransport::new(self.consumed(), ws_rx, ws_tx.take(), None),
)
}
}
Expand Down