Explorar el Código

Finished modifying btmsg to allow for greater concurrency
in message processing, though the concurrency on the client
side is still limited.

Matthew Carr hace 2 años
padre
commit
e14675f852
Se han modificado 3 ficheros con 76 adiciones y 50 borrados
  1. 1 1
      crates/btmsg/src/callback_framed.rs
  2. 59 46
      crates/btmsg/src/lib.rs
  3. 16 3
      crates/btmsg/tests/tests.rs

+ 1 - 1
crates/btmsg/src/callback_framed.rs

@@ -51,7 +51,7 @@ impl<I> CallbackFramed<I> {
         if slice.len() < payload_len {
             return Ok(DecodeStatus::Reserve(payload_len - slice.len()));
         }
-        Ok(DecodeStatus::Consume( Self::FRAME_LEN_SZ + payload_len ))
+        Ok(DecodeStatus::Consume(Self::FRAME_LEN_SZ + payload_len))
     }
 }
 

+ 59 - 46
crates/btmsg/src/lib.rs

@@ -30,9 +30,9 @@ use std::{
 use tokio::{
     runtime::Handle,
     select,
-    sync::{broadcast, Mutex},
+    sync::{broadcast, Mutex, OwnedSemaphorePermit, Semaphore},
 };
-use tokio_util::codec::{Encoder, Framed, FramedParts};
+use tokio_util::codec::{Encoder, Framed, FramedParts, FramedWrite};
 
 /// Returns a [Receiver] bound to the given [IpAddr] which receives messages at the bind path of
 /// the given [Writecap] of the given credentials. The returned type can be used to make
@@ -215,7 +215,7 @@ pub trait Transmitter {
         T: 'call + SendMsg<'call>;
 
     /// Transmit a message to the connected [Receiver] without waiting for a reply.
-    fn send<'call, T: 'call + SendMsg<'call>>(&'call mut self, msg: T) -> Self::SendFut<'call, T>;
+    fn send<'call, T: 'call + SendMsg<'call>>(&'call self, msg: T) -> Self::SendFut<'call, T>;
 
     type CallFut<'call, T, F>: 'call + Future<Output = Result<F::Return>> + Send
     where
@@ -225,7 +225,7 @@ pub trait Transmitter {
 
     /// Transmit a message to the connected [Receiver], waits for a reply, then calls the given
     /// [DeserCallback] with the deserialized reply.
-    fn call<'call, T, F>(&'call mut self, msg: T, callback: F) -> Self::CallFut<'call, T, F>
+    fn call<'call, T, F>(&'call self, msg: T, callback: F) -> Self::CallFut<'call, T, F>
     where
         T: 'call + CallMsg<'call>,
         F: 'static + Send + Sync + DeserCallback;
@@ -308,7 +308,7 @@ impl<T: Serialize> Encoder<T> for MsgEncoder {
     }
 }
 
-type FramedMsg = Framed<SendStream, MsgEncoder>;
+type FramedMsg = FramedWrite<SendStream, MsgEncoder>;
 type ArcMutex<T> = Arc<Mutex<T>>;
 
 #[derive(Clone)]
@@ -335,10 +335,10 @@ struct MsgRecvdCallback<F> {
 }
 
 impl<F: MsgCallback> MsgRecvdCallback<F> {
-    fn new(path: Arc<BlockPath>, framed_msg: ArcMutex<FramedMsg>, inner: F) -> Self {
+    fn new(path: Arc<BlockPath>, framed_msg: FramedMsg, inner: F) -> Self {
         Self {
             path,
-            replier: Replier::new(framed_msg),
+            replier: Replier::new(Arc::new(Mutex::new(framed_msg))),
             inner,
         }
     }
@@ -411,6 +411,9 @@ struct QuicReceiver {
 }
 
 impl QuicReceiver {
+    /// This defines the maximum number of blocking tasks which can be spawned at once.
+    const BLOCKING_LIMIT: usize = 16;
+
     fn new<F: 'static + MsgCallback>(
         recv_addr: Arc<BlockAddr>,
         resolver: Arc<CertResolver>,
@@ -433,62 +436,72 @@ impl QuicReceiver {
         callback: F,
         mut stop_rx: broadcast::Receiver<()>,
     ) {
+        let blocking_permits = Arc::new(Semaphore::new(Self::BLOCKING_LIMIT));
         loop {
             let connecting = await_or_stop!(endpoint.accept(), stop_rx.recv());
             let connection = unwrap_or_continue!(connecting.await, |err| error!(
                 "error accepting QUIC connection: {err}"
             ));
-            let callback = callback.clone();
-            let stop_rx = stop_rx.resubscribe();
-            // spawn_blocking is used to allow the user supplied callback to to block without
-            // disrupting the main thread pool.
-            tokio::task::spawn_blocking(move || {
-                Handle::current().block_on(Self::handle_connection(connection, callback, stop_rx))
-            });
+            tokio::spawn(Self::handle_connection(
+                connection,
+                callback.clone(),
+                stop_rx.resubscribe(),
+                blocking_permits.clone(),
+            ));
         }
     }
 
-    async fn handle_connection<F: MsgCallback>(
+    async fn handle_connection<F: 'static + MsgCallback>(
         connection: Connection,
         callback: F,
         mut stop_rx: broadcast::Receiver<()>,
+        blocking_permits: Arc<Semaphore>,
     ) {
         let client_path = unwrap_or_return!(
             Self::client_path(connection.peer_identity()),
             |err| error!("failed to get client path from peer identity: {err}")
         );
-        let mut frame_parts_opt: Option<FramedParts<SendStream, MsgEncoder>> = None;
         loop {
             let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv());
             let (send_stream, recv_stream) =
                 unwrap_or_continue!(result, |err| error!("error accepting stream: {err}"));
-            let frame_parts = match frame_parts_opt {
-                Some(mut frame_parts) => {
-                    frame_parts.io = send_stream;
-                    frame_parts
-                }
-                None => FramedParts::new::<<<F as MsgCallback>::Arg<'_> as CallMsg>::Reply<'_>>(
+            let permit = unwrap_or_continue!(blocking_permits.clone().acquire_owned().await);
+            let client_path = client_path.clone();
+            let callback = callback.clone();
+            // spawn_blocking is used to allow the user supplied callback to to block without
+            // disrupting the main thread pool.
+            tokio::task::spawn_blocking(move || {
+                Handle::current().block_on(Self::handle_message(
+                    client_path,
                     send_stream,
-                    MsgEncoder::new(),
-                ),
-            };
-            let framed_msg = Arc::new(Mutex::new(FramedMsg::from_parts(frame_parts)));
-            let callback =
-                MsgRecvdCallback::new(client_path.clone(), framed_msg.clone(), callback.clone());
-            let mut msg_stream = CallbackFramed::new(recv_stream);
-            let result = msg_stream
-                .next(callback)
-                .await
-                .ok_or_else(|| bterr!("client closed stream before sending a message"));
-            let msg_framed = Arc::try_unwrap(framed_msg).unwrap();
-            let msg_framed = msg_framed.into_inner();
-            frame_parts_opt = Some(msg_framed.into_parts());
-            match unwrap_or_continue!(result) {
-                Err(err) => error!("msg_stream produced an error: {err}"),
-                Ok(result) => {
-                    if let Err(err) = result {
-                        error!("callback returned an error: {err}");
-                    }
+                    recv_stream,
+                    permit,
+                    callback,
+                ))
+            });
+        }
+    }
+
+    async fn handle_message<F: MsgCallback>(
+        client_path: Arc<BlockPath>,
+        send_stream: SendStream,
+        recv_stream: RecvStream,
+        // This argument must be kept alive until this method returns.
+        _permit: OwnedSemaphorePermit,
+        callback: F,
+    ) {
+        let framed_msg = FramedWrite::new(send_stream, MsgEncoder::new());
+        let callback = MsgRecvdCallback::new(client_path.clone(), framed_msg, callback.clone());
+        let mut msg_stream = CallbackFramed::new(recv_stream);
+        let result = msg_stream
+            .next(callback)
+            .await
+            .ok_or_else(|| bterr!("client closed stream before sending a message"));
+        match unwrap_or_return!(result) {
+            Err(err) => error!("msg_stream produced an error: {err}"),
+            Ok(result) => {
+                if let Err(err) = result {
+                    error!("callback returned an error: {err}");
                 }
             }
         }
@@ -575,7 +588,7 @@ impl QuicTransmitter {
         })
     }
 
-    async fn transmit<T: Serialize>(&mut self, envelope: Envelope<T>) -> Result<RecvStream> {
+    async fn transmit<T: Serialize>(&self, envelope: Envelope<T>) -> Result<RecvStream> {
         let mut guard = self.send_parts.lock().await;
         let (send_stream, recv_stream) = self.connection.open_bi().await?;
         let parts = match guard.take() {
@@ -593,7 +606,7 @@ impl QuicTransmitter {
         Ok(recv_stream)
     }
 
-    async fn call<'ser, T, F>(&'ser mut self, msg: T, callback: F) -> Result<F::Return>
+    async fn call<'ser, T, F>(&'ser self, msg: T, callback: F) -> Result<F::Return>
     where
         T: 'ser + CallMsg<'ser>,
         F: 'static + Send + Sync + DeserCallback,
@@ -621,7 +634,7 @@ impl Transmitter for QuicTransmitter {
     type SendFut<'ser, T> = impl 'ser + Future<Output = Result<()>> + Send
         where T: 'ser + SendMsg<'ser>;
 
-    fn send<'ser, 'r, T: 'ser + SendMsg<'ser>>(&'ser mut self, msg: T) -> Self::SendFut<'ser, T> {
+    fn send<'ser, 'r, T: 'ser + SendMsg<'ser>>(&'ser self, msg: T) -> Self::SendFut<'ser, T> {
         self.transmit(Envelope::send(msg))
             .map(|result| result.map(|_| ()))
     }
@@ -632,7 +645,7 @@ impl Transmitter for QuicTransmitter {
         T: 'ser + CallMsg<'ser>,
         F: 'static + Send + Sync + DeserCallback;
 
-    fn call<'ser, T, F>(&'ser mut self, msg: T, callback: F) -> Self::CallFut<'ser, T, F>
+    fn call<'ser, T, F>(&'ser self, msg: T, callback: F) -> Self::CallFut<'ser, T, F>
     where
         T: 'ser + CallMsg<'ser>,
         F: 'static + Send + Sync + DeserCallback,

+ 16 - 3
crates/btmsg/tests/tests.rs

@@ -8,6 +8,7 @@ use btlib::{
 };
 use core::future::{ready, Future, Ready};
 use ctor::ctor;
+use futures::join;
 use lazy_static::lazy_static;
 use serde::{Deserialize, Serialize};
 use std::{
@@ -200,7 +201,7 @@ macro_rules! recv {
 #[tokio::test]
 async fn message_received_is_message_sent() {
     let (sender, mut passed) = mpsc::channel(1);
-    let (mut sender, _receiver) = proc_tx_rx(Delegate::new(
+    let (sender, _receiver) = proc_tx_rx(Delegate::new(
         sender,
         |msg: MsgReceived<Msg<'_>>, sender: Sender<bool>| {
             let passed = if let Msg::Ping = msg.body() {
@@ -225,7 +226,7 @@ async fn message_received_is_message_sent() {
 #[tokio::test]
 async fn message_received_from_path_is_correct() {
     let (sender, mut path) = mpsc::channel(1);
-    let (mut sender, receiver) = proc_tx_rx(Delegate::new(
+    let (sender, receiver) = proc_tx_rx(Delegate::new(
         sender,
         |msg: MsgReceived<Msg<'_>>, sender: Sender<Arc<BlockPath>>| {
             let path = msg.from().clone();
@@ -396,7 +397,7 @@ async fn read_server() -> (impl Transmitter, impl Receiver) {
 
 #[tokio::test]
 async fn call_with_lifetime() {
-    let (mut sender, _receiver) = read_server().await;
+    let (sender, _receiver) = read_server().await;
 
     let correct_one = sender
         .call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]))
@@ -410,3 +411,15 @@ async fn call_with_lifetime() {
     assert!(correct_one);
     assert!(correct_two);
 }
+
+#[tokio::test]
+async fn call_concurrently() {
+    let (sender, _receiver) = read_server().await;
+
+    let call_one = sender.call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]));
+    let call_two = sender.call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]));
+    let (result_one, result_two) = join!(call_one, call_two);
+
+    assert!(result_one.unwrap());
+    assert!(result_two.unwrap());
+}