|
@@ -30,9 +30,9 @@ use std::{
|
|
|
use tokio::{
|
|
|
runtime::Handle,
|
|
|
select,
|
|
|
- sync::{broadcast, Mutex},
|
|
|
+ sync::{broadcast, Mutex, OwnedSemaphorePermit, Semaphore},
|
|
|
};
|
|
|
-use tokio_util::codec::{Encoder, Framed, FramedParts};
|
|
|
+use tokio_util::codec::{Encoder, Framed, FramedParts, FramedWrite};
|
|
|
|
|
|
/// Returns a [Receiver] bound to the given [IpAddr] which receives messages at the bind path of
|
|
|
/// the given [Writecap] of the given credentials. The returned type can be used to make
|
|
@@ -215,7 +215,7 @@ pub trait Transmitter {
|
|
|
T: 'call + SendMsg<'call>;
|
|
|
|
|
|
/// Transmit a message to the connected [Receiver] without waiting for a reply.
|
|
|
- fn send<'call, T: 'call + SendMsg<'call>>(&'call mut self, msg: T) -> Self::SendFut<'call, T>;
|
|
|
+ fn send<'call, T: 'call + SendMsg<'call>>(&'call self, msg: T) -> Self::SendFut<'call, T>;
|
|
|
|
|
|
type CallFut<'call, T, F>: 'call + Future<Output = Result<F::Return>> + Send
|
|
|
where
|
|
@@ -225,7 +225,7 @@ pub trait Transmitter {
|
|
|
|
|
|
/// Transmit a message to the connected [Receiver], waits for a reply, then calls the given
|
|
|
/// [DeserCallback] with the deserialized reply.
|
|
|
- fn call<'call, T, F>(&'call mut self, msg: T, callback: F) -> Self::CallFut<'call, T, F>
|
|
|
+ fn call<'call, T, F>(&'call self, msg: T, callback: F) -> Self::CallFut<'call, T, F>
|
|
|
where
|
|
|
T: 'call + CallMsg<'call>,
|
|
|
F: 'static + Send + Sync + DeserCallback;
|
|
@@ -308,7 +308,7 @@ impl<T: Serialize> Encoder<T> for MsgEncoder {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-type FramedMsg = Framed<SendStream, MsgEncoder>;
|
|
|
+type FramedMsg = FramedWrite<SendStream, MsgEncoder>;
|
|
|
type ArcMutex<T> = Arc<Mutex<T>>;
|
|
|
|
|
|
#[derive(Clone)]
|
|
@@ -335,10 +335,10 @@ struct MsgRecvdCallback<F> {
|
|
|
}
|
|
|
|
|
|
impl<F: MsgCallback> MsgRecvdCallback<F> {
|
|
|
- fn new(path: Arc<BlockPath>, framed_msg: ArcMutex<FramedMsg>, inner: F) -> Self {
|
|
|
+ fn new(path: Arc<BlockPath>, framed_msg: FramedMsg, inner: F) -> Self {
|
|
|
Self {
|
|
|
path,
|
|
|
- replier: Replier::new(framed_msg),
|
|
|
+ replier: Replier::new(Arc::new(Mutex::new(framed_msg))),
|
|
|
inner,
|
|
|
}
|
|
|
}
|
|
@@ -411,6 +411,9 @@ struct QuicReceiver {
|
|
|
}
|
|
|
|
|
|
impl QuicReceiver {
|
|
|
+ /// This defines the maximum number of blocking tasks which can be spawned at once.
|
|
|
+ const BLOCKING_LIMIT: usize = 16;
|
|
|
+
|
|
|
fn new<F: 'static + MsgCallback>(
|
|
|
recv_addr: Arc<BlockAddr>,
|
|
|
resolver: Arc<CertResolver>,
|
|
@@ -433,62 +436,72 @@ impl QuicReceiver {
|
|
|
callback: F,
|
|
|
mut stop_rx: broadcast::Receiver<()>,
|
|
|
) {
|
|
|
+ let blocking_permits = Arc::new(Semaphore::new(Self::BLOCKING_LIMIT));
|
|
|
loop {
|
|
|
let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv());
|
|
|
let connection = unwrap_or_continue!(connecting.await, |err| error!(
|
|
|
"error accepting QUIC connection: {err}"
|
|
|
));
|
|
|
- let callback = callback.clone();
|
|
|
- let stop_rx = stop_rx.resubscribe();
|
|
|
- // spawn_blocking is used to allow the user supplied callback to to block without
|
|
|
- // disrupting the main thread pool.
|
|
|
- tokio::task::spawn_blocking(move || {
|
|
|
- Handle::current().block_on(Self::handle_connection(connection, callback, stop_rx))
|
|
|
- });
|
|
|
+ tokio::spawn(Self::handle_connection(
|
|
|
+ connection,
|
|
|
+ callback.clone(),
|
|
|
+ stop_rx.resubscribe(),
|
|
|
+ blocking_permits.clone(),
|
|
|
+ ));
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- async fn handle_connection<F: MsgCallback>(
|
|
|
+ async fn handle_connection<F: 'static + MsgCallback>(
|
|
|
connection: Connection,
|
|
|
callback: F,
|
|
|
mut stop_rx: broadcast::Receiver<()>,
|
|
|
+ blocking_permits: Arc<Semaphore>,
|
|
|
) {
|
|
|
let client_path = unwrap_or_return!(
|
|
|
Self::client_path(connection.peer_identity()),
|
|
|
|err| error!("failed to get client path from peer identity: {err}")
|
|
|
);
|
|
|
- let mut frame_parts_opt: Option<FramedParts<SendStream, MsgEncoder>> = None;
|
|
|
loop {
|
|
|
let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv());
|
|
|
let (send_stream, recv_stream) =
|
|
|
unwrap_or_continue!(result, |err| error!("error accepting stream: {err}"));
|
|
|
- let frame_parts = match frame_parts_opt {
|
|
|
- Some(mut frame_parts) => {
|
|
|
- frame_parts.io = send_stream;
|
|
|
- frame_parts
|
|
|
- }
|
|
|
- None => FramedParts::new::<<<F as MsgCallback>::Arg<'_> as CallMsg>::Reply<'_>>(
|
|
|
+ let permit = unwrap_or_continue!(blocking_permits.clone().acquire_owned().await);
|
|
|
+ let client_path = client_path.clone();
|
|
|
+ let callback = callback.clone();
|
|
|
+ // spawn_blocking is used to allow the user supplied callback to to block without
|
|
|
+ // disrupting the main thread pool.
|
|
|
+ tokio::task::spawn_blocking(move || {
|
|
|
+ Handle::current().block_on(Self::handle_message(
|
|
|
+ client_path,
|
|
|
send_stream,
|
|
|
- MsgEncoder::new(),
|
|
|
- ),
|
|
|
- };
|
|
|
- let framed_msg = Arc::new(Mutex::new(FramedMsg::from_parts(frame_parts)));
|
|
|
- let callback =
|
|
|
- MsgRecvdCallback::new(client_path.clone(), framed_msg.clone(), callback.clone());
|
|
|
- let mut msg_stream = CallbackFramed::new(recv_stream);
|
|
|
- let result = msg_stream
|
|
|
- .next(callback)
|
|
|
- .await
|
|
|
- .ok_or_else(|| bterr!("client closed stream before sending a message"));
|
|
|
- let msg_framed = Arc::try_unwrap(framed_msg).unwrap();
|
|
|
- let msg_framed = msg_framed.into_inner();
|
|
|
- frame_parts_opt = Some(msg_framed.into_parts());
|
|
|
- match unwrap_or_continue!(result) {
|
|
|
- Err(err) => error!("msg_stream produced an error: {err}"),
|
|
|
- Ok(result) => {
|
|
|
- if let Err(err) = result {
|
|
|
- error!("callback returned an error: {err}");
|
|
|
- }
|
|
|
+ recv_stream,
|
|
|
+ permit,
|
|
|
+ callback,
|
|
|
+ ))
|
|
|
+ });
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ async fn handle_message<F: MsgCallback>(
|
|
|
+ client_path: Arc<BlockPath>,
|
|
|
+ send_stream: SendStream,
|
|
|
+ recv_stream: RecvStream,
|
|
|
+ // This argument must be kept alive until this method returns.
|
|
|
+ _permit: OwnedSemaphorePermit,
|
|
|
+ callback: F,
|
|
|
+ ) {
|
|
|
+ let framed_msg = FramedWrite::new(send_stream, MsgEncoder::new());
|
|
|
+ let callback = MsgRecvdCallback::new(client_path.clone(), framed_msg, callback.clone());
|
|
|
+ let mut msg_stream = CallbackFramed::new(recv_stream);
|
|
|
+ let result = msg_stream
|
|
|
+ .next(callback)
|
|
|
+ .await
|
|
|
+ .ok_or_else(|| bterr!("client closed stream before sending a message"));
|
|
|
+ match unwrap_or_return!(result) {
|
|
|
+ Err(err) => error!("msg_stream produced an error: {err}"),
|
|
|
+ Ok(result) => {
|
|
|
+ if let Err(err) = result {
|
|
|
+ error!("callback returned an error: {err}");
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -575,7 +588,7 @@ impl QuicTransmitter {
|
|
|
})
|
|
|
}
|
|
|
|
|
|
- async fn transmit<T: Serialize>(&mut self, envelope: Envelope<T>) -> Result<RecvStream> {
|
|
|
+ async fn transmit<T: Serialize>(&self, envelope: Envelope<T>) -> Result<RecvStream> {
|
|
|
let mut guard = self.send_parts.lock().await;
|
|
|
let (send_stream, recv_stream) = self.connection.open_bi().await?;
|
|
|
let parts = match guard.take() {
|
|
@@ -593,7 +606,7 @@ impl QuicTransmitter {
|
|
|
Ok(recv_stream)
|
|
|
}
|
|
|
|
|
|
- async fn call<'ser, T, F>(&'ser mut self, msg: T, callback: F) -> Result<F::Return>
|
|
|
+ async fn call<'ser, T, F>(&'ser self, msg: T, callback: F) -> Result<F::Return>
|
|
|
where
|
|
|
T: 'ser + CallMsg<'ser>,
|
|
|
F: 'static + Send + Sync + DeserCallback,
|
|
@@ -621,7 +634,7 @@ impl Transmitter for QuicTransmitter {
|
|
|
type SendFut<'ser, T> = impl 'ser + Future<Output = Result<()>> + Send
|
|
|
where T: 'ser + SendMsg<'ser>;
|
|
|
|
|
|
- fn send<'ser, 'r, T: 'ser + SendMsg<'ser>>(&'ser mut self, msg: T) -> Self::SendFut<'ser, T> {
|
|
|
+ fn send<'ser, 'r, T: 'ser + SendMsg<'ser>>(&'ser self, msg: T) -> Self::SendFut<'ser, T> {
|
|
|
self.transmit(Envelope::send(msg))
|
|
|
.map(|result| result.map(|_| ()))
|
|
|
}
|
|
@@ -632,7 +645,7 @@ impl Transmitter for QuicTransmitter {
|
|
|
T: 'ser + CallMsg<'ser>,
|
|
|
F: 'static + Send + Sync + DeserCallback;
|
|
|
|
|
|
- fn call<'ser, T, F>(&'ser mut self, msg: T, callback: F) -> Self::CallFut<'ser, T, F>
|
|
|
+ fn call<'ser, T, F>(&'ser self, msg: T, callback: F) -> Self::CallFut<'ser, T, F>
|
|
|
where
|
|
|
T: 'ser + CallMsg<'ser>,
|
|
|
F: 'static + Send + Sync + DeserCallback,
|