Bladeren bron

Reworked btmsg to enable zero copy message processing.

Matthew Carr 2 jaren geleden
bovenliggende
commit
245699fe79

+ 1 - 1
crates/btfuse/src/main.rs

@@ -72,7 +72,7 @@ struct FuseDaemon<'a> {
 }
 
 impl<'a> FuseDaemon<'a> {
-    fn new(path: &'a Path, tabrmd_config: &'a str) -> FuseDaemon<'_> {
+    fn new(path: &'a Path, tabrmd_config: &'a str) -> FuseDaemon<'a> {
         FuseDaemon {
             path,
             tabrmd_config,

+ 6 - 6
crates/btlib/src/blocktree.rs

@@ -162,11 +162,11 @@ mod private {
     /// A trait for types which can render authorization decisions.
     pub trait Authorizer {
         /// Returns [Ok] if read authorization is granted, and [Err] otherwise.
-        fn can_read<'a>(&self, ctx: &AuthzContext<'a>) -> io::Result<()>;
+        fn can_read(&self, ctx: &AuthzContext<'_>) -> io::Result<()>;
         /// Returns [Ok] if write authorization is granted, and [Err] otherwise.
-        fn can_write<'a>(&self, ctx: &AuthzContext<'a>) -> io::Result<()>;
+        fn can_write(&self, ctx: &AuthzContext<'_>) -> io::Result<()>;
         /// Returns [Ok] if execute authorization is granted, and [Err] otherwise.
-        fn can_exec<'a>(&self, ctx: &AuthzContext<'a>) -> io::Result<()>;
+        fn can_exec(&self, ctx: &AuthzContext<'_>) -> io::Result<()>;
     }
 
     trait AuthorizerExt: Authorizer {
@@ -204,7 +204,7 @@ mod private {
     }
 
     impl Authorizer for ModeAuthorizer {
-        fn can_read<'a>(&self, ctx: &AuthzContext<'a>) -> io::Result<()> {
+        fn can_read(&self, ctx: &AuthzContext<'_>) -> io::Result<()> {
             if Self::user_is_root(ctx) {
                 return Ok(());
             }
@@ -215,7 +215,7 @@ mod private {
             Self::authorize(secrets.mode, mask, "read access denied")
         }
 
-        fn can_write<'a>(&self, ctx: &AuthzContext<'a>) -> io::Result<()> {
+        fn can_write(&self, ctx: &AuthzContext<'_>) -> io::Result<()> {
             if Self::user_is_root(ctx) {
                 return Ok(());
             }
@@ -226,7 +226,7 @@ mod private {
             Self::authorize(secrets.mode, mask, "write access denied")
         }
 
-        fn can_exec<'a>(&self, ctx: &AuthzContext<'a>) -> io::Result<()> {
+        fn can_exec(&self, ctx: &AuthzContext<'_>) -> io::Result<()> {
             if Self::user_is_root(ctx) {
                 return Ok(());
             }

+ 5 - 6
crates/btlib/src/crypto.rs

@@ -128,18 +128,15 @@ impl Display for Error {
             }
             Error::IndexOutOfBounds { index, limit } => write!(
                 f,
-                "index {} is out of bounds, it must be strictly less than {}",
-                index, limit
+                "index {index} is out of bounds, it must be strictly less than {limit}",
             ),
             Error::IndivisibleSize { divisor, actual } => write!(
                 f,
-                "expected a size which is divisible by {} but got {}",
-                divisor, actual
+                "expected a size which is divisible by {divisor} but got {actual}",
             ),
             Error::InvalidOffset { actual, limit } => write!(
                 f,
-                "offset {} is out of bounds, it must be strictly less than {}",
-                actual, limit
+                "offset {actual} is out of bounds, it must be strictly less than {limit}",
             ),
             Error::HashCmpFailure => write!(f, "hash data are not equal"),
             Error::RootHashNotVerified => write!(f, "root hash is not verified"),
@@ -493,6 +490,7 @@ pub enum VarHash {
     Sha2_512(Sha2_512),
 }
 
+#[allow(clippy::derivable_impls)]
 impl Default for HashKind {
     fn default() -> HashKind {
         HashKind::Sha2_256
@@ -960,6 +958,7 @@ impl Decrypter for SymKey {
     }
 }
 
+#[allow(clippy::derivable_impls)]
 impl Default for SymKeyKind {
     fn default() -> Self {
         SymKeyKind::Aes256Ctr

+ 160 - 0
crates/btmsg/src/callback_framed.rs

@@ -0,0 +1,160 @@
+use btlib::{error::BoxInIoErr, Result};
+use btserde::{from_slice, read_from};
+use bytes::BytesMut;
+use futures::Future;
+use serde::Deserialize;
+use tokio::io::{AsyncRead, AsyncReadExt};
+
+pub struct CallbackFramed<I> {
+    io: I,
+    buffer: BytesMut,
+}
+
+impl<I> CallbackFramed<I> {
+    const INIT_CAPACITY: usize = 4096;
+
+    pub fn new(inner: I) -> Self {
+        Self {
+            io: inner,
+            buffer: BytesMut::with_capacity(Self::INIT_CAPACITY),
+        }
+    }
+
+    async fn decode<'de, F: 'de + DeserCallback>(
+        mut slice: &'de [u8],
+        callback: &'de F,
+    ) -> Result<DecodeStatus<F::Return>> {
+        let payload_len: u64 = match read_from(&mut slice) {
+            Ok(payload_len) => payload_len,
+            Err(err) => {
+                return match err {
+                    btserde::Error::Eof => Ok(DecodeStatus::None),
+                    btserde::Error::Io(ref io_err) => match io_err.kind() {
+                        std::io::ErrorKind::UnexpectedEof => Ok(DecodeStatus::None),
+                        _ => Err(err.into()),
+                    },
+                    _ => Err(err.into()),
+                }
+            }
+        };
+        let payload_len: usize = payload_len.try_into().box_err()?;
+        if slice.len() < payload_len {
+            return Ok(DecodeStatus::Reserve(payload_len - slice.len()));
+        }
+        let msg: F::Arg<'de> = from_slice(slice)?;
+        let returned = callback.call(msg).await;
+        Ok(DecodeStatus::Some {
+            returned,
+            consumed: std::mem::size_of::<u64>() + payload_len,
+        })
+    }
+}
+
+macro_rules! attempt {
+    ($result:expr) => {
+        match $result {
+            Ok(value) => value,
+            Err(err) => return Some(Err(err.into())),
+        }
+    };
+}
+
+impl<S: AsyncRead + Unpin> CallbackFramed<S> {
+    pub async fn next<F: DeserCallback>(&mut self, callback: F) -> Option<Result<F::Return>> {
+        loop {
+            if self.buffer.capacity() - self.buffer.len() == 0 {
+                // If there is no space left in the buffer we reserve additional bytes to ensure
+                // read_buf doesn't return 0 unless we're at EOF.
+                self.buffer.reserve(1);
+            }
+            let read_ct = attempt!(self.io.read_buf(&mut self.buffer).await);
+            if 0 == read_ct {
+                return None;
+            }
+            match attempt!(Self::decode(&self.buffer[..read_ct], &callback).await) {
+                DecodeStatus::None => continue,
+                DecodeStatus::Reserve(count) => {
+                    self.buffer.reserve(count);
+                    continue;
+                }
+                DecodeStatus::Some { returned, consumed } => {
+                    let _ = self.buffer.split_to(consumed);
+                    return Some(Ok(returned));
+                }
+            }
+        }
+    }
+}
+
+enum DecodeStatus<R> {
+    None,
+    Reserve(usize),
+    Some { returned: R, consumed: usize },
+}
+
+pub trait DeserCallback: Clone {
+    type Arg<'de>: Deserialize<'de> + Send
+    where
+        Self: 'de;
+    type Return;
+    type CallFut<'s>: Future<Output = Self::Return> + Send
+    where
+        Self: 's;
+    fn call<'de>(&'de self, arg: Self::Arg<'de>) -> Self::CallFut<'de>;
+}
+
+impl<F: DeserCallback> DeserCallback for &F {
+    type Arg<'de> = F::Arg<'de> where Self: 'de;
+    type Return = F::Return;
+    type CallFut<'f> = F::CallFut<'f> where Self: 'f;
+    fn call<'de>(&'de self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
+        (*self).call(arg)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use crate::MsgEncoder;
+
+    use futures::{future::Ready, SinkExt};
+    use serde::Serialize;
+    use std::io::{Cursor, Seek};
+    use tokio_util::codec::FramedWrite;
+
+    #[derive(Serialize, Deserialize)]
+    struct Msg<'a>(&'a [u8]);
+
+    #[tokio::test]
+    async fn read_single_message() {
+        macro_rules! test_data {
+            () => {
+                b"fulcrum"
+            };
+        }
+
+        #[derive(Clone)]
+        struct TestCb;
+
+        impl DeserCallback for TestCb {
+            type Arg<'de> = Msg<'de> where Self: 'de;
+            type Return = bool;
+            type CallFut<'f> = Ready<bool>;
+
+            fn call<'de>(&'de self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
+                futures::future::ready(arg.0 == test_data!())
+            }
+        }
+
+        let mut write = FramedWrite::new(Cursor::new(Vec::<u8>::new()), MsgEncoder);
+        write.send(Msg(test_data!())).await.unwrap();
+        let mut io = write.into_inner();
+        io.rewind().unwrap();
+        let mut read = CallbackFramed::new(io);
+
+        let matched = read.next(TestCb).await.unwrap().unwrap();
+
+        assert!(matched)
+    }
+}

+ 185 - 217
crates/btmsg/src/lib.rs

@@ -2,23 +2,14 @@
 
 mod tls;
 use tls::*;
+mod callback_framed;
+use callback_framed::{CallbackFramed, DeserCallback};
 
 use btlib::{bterr, crypto::Creds, error::BoxInIoErr, BlockPath, Result, Writecap};
 use btserde::{read_from, write_to};
 use bytes::{BufMut, BytesMut};
-use core::{
-    future::Future,
-    marker::Send,
-    ops::DerefMut,
-    pin::Pin,
-    task::{Context, Poll},
-};
-use futures::{
-    future::{ready, Ready},
-    sink::Send as SendFut,
-    stream::Stream,
-    SinkExt, StreamExt,
-};
+use core::{future::Future, marker::Send, ops::DerefMut, pin::Pin};
+use futures::{sink::Send as SendFut, SinkExt, StreamExt};
 use log::error;
 use quinn::{Connection, Endpoint, RecvStream, SendStream};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
@@ -27,27 +18,69 @@ use std::{
     collections::hash_map::DefaultHasher,
     hash::{Hash, Hasher},
     marker::PhantomData,
-    net::{IpAddr, SocketAddr},
+    net::{IpAddr, Ipv6Addr, SocketAddr},
     sync::Arc,
 };
 use tokio::{
     select,
-    sync::{broadcast, mpsc, Mutex},
+    sync::{broadcast, Mutex},
 };
-use tokio_stream::wrappers::ReceiverStream;
 use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts, FramedRead, FramedWrite};
 
-/// Returns a [Router] which can be used to make a [Receiver] for the given path and
-///  [Sender] instances for any path.
-pub fn router<C: 'static + Creds + Send + Sync>(
+/// Returns a [Receiver] bound to the given [IpAddr] which receives messages at the bind path of
+/// the given [Writecap] of the given credentials. The returned type can be used to make
+/// [Transmitter]s for any path.
+pub fn receiver<C: 'static + Creds + Send + Sync, F: 'static + MsgCallback>(
     ip_addr: IpAddr,
     creds: Arc<C>,
-) -> Result<impl Router> {
+    callback: F,
+) -> Result<impl Receiver> {
     let writecap = creds.writecap().ok_or(btlib::BlockError::MissingWritecap)?;
     let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
-    QuicRouter::new(addr, Arc::new(CertResolver::new(creds)?))
+    QuicReceiver::new(addr, Arc::new(CertResolver::new(creds)?), callback)
 }
 
+pub async fn transmitter<C: 'static + Creds + Send + Sync>(
+    addr: Arc<BlockAddr>,
+    creds: Arc<C>,
+) -> Result<impl Transmitter> {
+    let resolver = Arc::new(CertResolver::new(creds)?);
+    let endpoint = Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0))?;
+    QuicTransmitter::from_endpoint(endpoint, addr, resolver).await
+}
+
+pub trait MsgCallback: Clone + Send + Sync + Unpin {
+    type Arg<'de>: CallMsg<'de>
+    where
+        Self: 'de;
+    type Return;
+    type CallFut<'de>: Future<Output = Self::Return> + Send
+    where
+        Self: 'de;
+    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de>;
+}
+
+impl<T: MsgCallback> MsgCallback for &T {
+    type Arg<'de> = T::Arg<'de> where Self: 'de;
+    type Return = T::Return;
+    type CallFut<'de> = T::CallFut<'de> where Self: 'de;
+    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
+        (*self).call(arg)
+    }
+}
+
+/// Trait for messages which can be transmitted using the call method.
+pub trait CallMsg<'de>: Serialize + Deserialize<'de> + Send + Sync {
+    type Reply: Serialize + DeserializeOwned + Send;
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum NoReply {}
+
+/// Trait for messages which can be transmitted using the send method.
+/// Types which implement this trait should specify [NoReply] as their reply type.
+pub trait SendMsg<'de>: CallMsg<'de> {}
+
 /// An address which identifies a block on the network. An instance of this struct can be
 /// used to get a socket address for the block this address refers to.
 #[derive(PartialEq, Eq, Hash, Clone, Debug)]
@@ -87,24 +120,6 @@ impl BlockAddr {
     }
 }
 
-/// Trait for messages which can be transmitted using the call method.
-pub trait CallTx: Serialize + Send {
-    type Reply: 'static + DeserializeOwned + Send;
-}
-
-/// Trait for messages which can be transmitted using the send method.
-/// Types which implement this trait should choose `()` as their reply type.
-pub trait SendTx: CallTx {}
-
-/// Trait for messages which are received from the call method.
-pub trait CallRx: 'static + DeserializeOwned + Send {
-    type Reply<'a>: 'a + Serialize + Send;
-}
-
-/// Trait for messages which are received from the send method.
-/// Types which implement this trait should choose `()` as their reply type.
-pub trait SendRx: CallRx {}
-
 /// Indicates whether a message was sent using `call` or `send`.
 #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Hash, Clone)]
 enum MsgKind {
@@ -144,11 +159,11 @@ impl<T> Envelope<T> {
 pub struct MsgReceived<T> {
     from: Arc<BlockPath>,
     msg: Envelope<T>,
-    replier: Replier,
+    replier: Option<Replier>,
 }
 
 impl<T> MsgReceived<T> {
-    fn new(from: Arc<BlockPath>, msg: Envelope<T>, replier: Replier) -> Self {
+    fn new(from: Arc<BlockPath>, msg: Envelope<T>, replier: Option<Replier>) -> Self {
         Self { from, msg, replier }
     }
 
@@ -164,44 +179,53 @@ impl<T> MsgReceived<T> {
 
     /// Returns true if and only if this messages needs to be replied to.
     pub fn needs_reply(&self) -> bool {
-        self.replier.parts.is_some()
+        self.replier.is_some()
     }
 }
 
-impl<T: CallRx> MsgReceived<T> {
-    /// Replies to this message. This method must be called exactly once for [CallRx] messages
-    /// and exactly zero times for [SendRx] messages. The `needs_reply` method can be called
-    /// on this instance to determine if a reply still needs to be sent.
-    pub async fn reply<'a>(&mut self, reply: T::Reply<'a>) -> Result<()> {
-        self.replier.reply(reply).await
+impl<'de, T: CallMsg<'de>> MsgReceived<T> {
+    /// Returns a type which can be used to reply to this message, if this message requires a
+    /// reply and it has not yet been sent.
+    pub fn take_replier(&mut self) -> Option<Replier> {
+        self.replier.take()
     }
 }
 
-/// A type which can be used to receive messages.
-pub trait Receiver<T>: Stream<Item = Result<MsgReceived<T>>> {
+/// Trait for receiving messages and creating [Transmitter]s.
+pub trait Receiver {
     /// The address at which messages will be received.
-    fn addr(&self) -> &BlockAddr;
+    fn addr(&self) -> &Arc<BlockAddr>;
+
+    type Transmitter: Transmitter + Send;
+    type TransmitterFut<'a>: 'a + Future<Output = Result<Self::Transmitter>> + Send
+    where
+        Self: 'a;
+
+    /// Creates a [Transmitter] which is connected to the given address.
+    fn transmitter(&self, addr: Arc<BlockAddr>) -> Self::TransmitterFut<'_>;
 }
 
 /// A type which can be used to transmit messages.
 pub trait Transmitter {
-    type SendFut<'a, T>: 'a + Future<Output = Result<()>> + Send
+    type SendFut<'s, T>: 's + Future<Output = Result<()>> + Send
     where
-        Self: 'a,
-        T: 'a + Serialize + Send;
+        Self: 's,
+        T: 's + Serialize + Send;
 
     /// Transmit a message to the connected [Receiver] without waiting for a reply.
-    fn send<'a, T: 'a + SendTx>(&'a mut self, msg: T) -> Self::SendFut<'a, T>;
+    fn send<'de, T: 'de + SendMsg<'de>>(&'de mut self, msg: T) -> Self::SendFut<'de, T>;
 
-    type CallFut<'a, T>: 'a + Future<Output = Result<T::Reply>> + Send
+    type CallFut<'s, 'de, T>: 's + Future<Output = Result<T::Reply>> + Send
     where
-        Self: 'a,
-        T: 'a + CallTx;
+        Self: 's,
+        T: 's + CallMsg<'de>,
+        T::Reply: 's;
 
     /// Transmit a message to the connected [Receiver] and wait for a reply.
-    fn call<'a, T>(&'a mut self, msg: T) -> Self::CallFut<'a, T>
+    fn call<'s, 'de, T>(&'s mut self, msg: T) -> Self::CallFut<'s, 'de, T>
     where
-        T: 'a + CallTx;
+        T: 's + CallMsg<'de>,
+        T::Reply: 's;
 
     type FinishFut: Future<Output = Result<()>> + Send;
 
@@ -209,27 +233,7 @@ pub trait Transmitter {
     fn finish(self) -> Self::FinishFut;
 
     /// Returns the address that this instance is transmitting to.
-    fn addr(&self) -> &BlockAddr;
-}
-
-/// Trait for types which can create [Transmitter]s and [Receiver]s.
-pub trait Router {
-    type Transmitter: Transmitter + Send;
-    type TransmitterFut<'a>: 'a + Future<Output = Result<Self::Transmitter>> + Send
-    where
-        Self: 'a;
-
-    /// Creates a [Transmitter] which is connected to the given address.
-    fn transmitter(&self, addr: Arc<BlockAddr>) -> Self::TransmitterFut<'_>;
-
-    type Receiver<T: CallRx>: Receiver<T> + Send + Unpin;
-    type ReceiverFut<'a, T>: 'a + Future<Output = Result<Self::Receiver<T>>> + Send
-    where
-        T: CallRx,
-        Self: 'a;
-
-    /// Creates a [Receiver] which will receive message at the address of this [Router].
-    fn receiver<T: CallRx>(&self) -> Self::ReceiverFut<'_, T>;
+    fn addr(&self) -> &Arc<BlockAddr>;
 }
 
 /// Encodes messages using [btserde].
@@ -299,82 +303,60 @@ impl<T: DeserializeOwned> Decoder for MsgDecoder<T> {
     }
 }
 
-struct QuicRouter {
-    recv_addr: Arc<BlockAddr>,
-    resolver: Arc<CertResolver>,
-    endpoint: Endpoint,
-}
-
-impl QuicRouter {
-    fn new(recv_addr: Arc<BlockAddr>, resolver: Arc<CertResolver>) -> Result<Self> {
-        let socket_addr = recv_addr.socket_addr();
-        let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?;
-        Ok(Self {
-            endpoint,
-            resolver,
-            recv_addr,
-        })
-    }
-}
-
-impl Router for QuicRouter {
-    type Receiver<T: CallRx> = QuicReceiver<T>;
-    type ReceiverFut<'a, T: CallRx> = Ready<Result<QuicReceiver<T>>>;
-
-    fn receiver<T: CallRx>(&self) -> Self::ReceiverFut<'_, T> {
-        ready(QuicReceiver::new(
-            self.endpoint.clone(),
-            self.recv_addr.clone(),
-        ))
-    }
-
-    type Transmitter = QuicSender;
-    type TransmitterFut<'a> = Pin<Box<dyn 'a + Future<Output = Result<QuicSender>> + Send>>;
-
-    fn transmitter(&self, addr: Arc<BlockAddr>) -> Self::TransmitterFut<'_> {
-        Box::pin(async {
-            QuicSender::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await
-        })
-    }
-}
-
 type SharedFrameParts = Arc<Mutex<Option<FramedParts<SendStream, MsgEncoder>>>>;
 
 #[derive(Clone)]
-struct Replier {
-    parts: Option<SharedFrameParts>,
+pub struct Replier {
+    parts: SharedFrameParts,
 }
 
 impl Replier {
     fn new(send_stream: SendStream) -> Self {
         let parts = FramedParts::new::<()>(send_stream, MsgEncoder::new());
-        let parts = Some(Arc::new(Mutex::new(Some(parts))));
+        let parts = Arc::new(Mutex::new(Some(parts)));
         Self { parts }
     }
 
-    fn empty() -> Self {
-        Self { parts: None }
+    pub async fn reply<T: Serialize + Send>(self, reply: T) -> Result<()> {
+        let parts = self.parts;
+        let mut guard = parts.lock().await;
+        // We must ensure the parts are put back before we leave this block.
+        let parts = guard.take().unwrap();
+        let mut stream = Framed::from_parts(parts);
+        let result = stream.send(reply).await;
+        *guard = Some(stream.into_parts());
+        result
     }
+}
 
-    async fn reply<T: Serialize + Send>(&mut self, reply: T) -> Result<()> {
-        let parts = self
-            .parts
-            .take()
-            .ok_or_else(|| bterr!("reply has already been sent"))?;
-        let result = {
-            let mut guard = parts.lock().await;
-            // We must ensure the parts are put back before we leave this block.
-            let parts = guard.take().unwrap();
-            let mut stream = Framed::from_parts(parts);
-            let result = stream.send(reply).await;
-            *guard = Some(stream.into_parts());
-            result
-        };
-        if result.is_err() {
-            // If the result is an error put back the parts so the caller may try again.
-            self.parts = Some(parts);
+#[derive(Clone)]
+struct MsgRecvdCallback<F> {
+    path: Arc<BlockPath>,
+    replier: Replier,
+    inner: F,
+}
+
+impl<F: MsgCallback> MsgRecvdCallback<F> {
+    fn new(path: Arc<BlockPath>, replier: Replier, inner: F) -> Self {
+        Self {
+            path,
+            replier,
+            inner,
         }
-        result
+    }
+}
+
+impl<F: MsgCallback> DeserCallback for MsgRecvdCallback<F> {
+    type Arg<'de> = Envelope<F::Arg<'de>> where Self: 'de;
+    type Return = F::Return;
+    type CallFut<'s> = F::CallFut<'s> where F: 's;
+    fn call<'de>(&'de self, arg: Envelope<F::Arg<'de>>) -> Self::CallFut<'de> {
+        let replier = match arg.kind {
+            MsgKind::Call => Some(self.replier.clone()),
+            MsgKind::Send => None,
+        };
+        self.inner
+            .call(MsgReceived::new(self.path.clone(), arg, replier))
     }
 }
 
@@ -423,30 +405,34 @@ macro_rules! await_or_stop {
     };
 }
 
-struct QuicReceiver<T> {
+struct QuicReceiver {
     recv_addr: Arc<BlockAddr>,
     stop_tx: broadcast::Sender<()>,
-    stream: ReceiverStream<Result<MsgReceived<T>>>,
+    endpoint: Endpoint,
+    resolver: Arc<CertResolver>,
 }
 
-impl<T: CallRx> QuicReceiver<T> {
-    /// The size of the buffer to store received messages in.
-    const MSG_BUF_SZ: usize = 64;
-
-    fn new(endpoint: Endpoint, recv_addr: Arc<BlockAddr>) -> Result<Self> {
+impl QuicReceiver {
+    fn new<F: 'static + MsgCallback>(
+        recv_addr: Arc<BlockAddr>,
+        resolver: Arc<CertResolver>,
+        callback: F,
+    ) -> Result<Self> {
+        let socket_addr = recv_addr.socket_addr();
+        let endpoint = Endpoint::server(server_config(resolver.clone())?, socket_addr)?;
         let (stop_tx, stop_rx) = broadcast::channel(1);
-        let (msg_tx, msg_rx) = mpsc::channel(Self::MSG_BUF_SZ);
-        tokio::spawn(Self::server_loop(endpoint, msg_tx, stop_rx));
+        tokio::spawn(Self::server_loop(endpoint.clone(), callback, stop_rx));
         Ok(Self {
             recv_addr,
             stop_tx,
-            stream: ReceiverStream::new(msg_rx),
+            endpoint,
+            resolver,
         })
     }
 
-    async fn server_loop(
+    async fn server_loop<F: 'static + MsgCallback>(
         endpoint: Endpoint,
-        msg_tx: mpsc::Sender<Result<MsgReceived<T>>>,
+        callback: F,
         mut stop_rx: broadcast::Receiver<()>,
     ) {
         loop {
@@ -456,29 +442,15 @@ impl<T: CallRx> QuicReceiver<T> {
             ));
             tokio::spawn(Self::handle_connection(
                 connection,
-                msg_tx.clone(),
+                callback.clone(),
                 stop_rx.resubscribe(),
             ));
         }
     }
 
-    /// Returns the path the client is bound to.
-    fn client_path(peer_identity: Option<Box<dyn Any>>) -> Result<Arc<BlockPath>> {
-        let peer_identity =
-            peer_identity.ok_or_else(|| bterr!("connection did not contain a peer identity"))?;
-        let client_certs = peer_identity
-            .downcast::<Vec<rustls::Certificate>>()
-            .map_err(|_| bterr!("failed to downcast peer_identity to certificate chain"))?;
-        let first = client_certs
-            .first()
-            .ok_or_else(|| bterr!("no certificates were presented by the client"))?;
-        let (writecap, ..) = Writecap::from_cert_chain(first, &client_certs[1..])?;
-        Ok(Arc::new(writecap.bind_path()))
-    }
-
-    async fn handle_connection(
+    async fn handle_connection<F: MsgCallback>(
         connection: Connection,
-        msg_tx: mpsc::Sender<Result<MsgReceived<T>>>,
+        callback: F,
         mut stop_rx: broadcast::Receiver<()>,
     ) {
         let client_path = unwrap_or_return!(
@@ -490,28 +462,32 @@ impl<T: CallRx> QuicReceiver<T> {
             |err| error!("error accepting receive stream: {err}")
         );
         let replier = Replier::new(send_stream);
-        let mut msg_stream = FramedRead::new(recv_stream, MsgDecoder::<Envelope<T>>::new());
+        let callback = MsgRecvdCallback::new(client_path, replier, callback);
+        let mut msg_stream = CallbackFramed::new(recv_stream);
         loop {
-            let decode_result = await_or_stop!(msg_stream.next(), stop_rx.recv());
+            let decode_result = await_or_stop!(msg_stream.next(callback.clone()), stop_rx.recv());
             if let Err(ref err) = decode_result {
                 error!("msg_stream produced an error: {err}");
             }
-            let msg_received = decode_result.map(|envelope| {
-                let replier = match envelope.kind {
-                    MsgKind::Call => replier.clone(),
-                    MsgKind::Send => Replier::empty(),
-                };
-                MsgReceived::new(client_path.clone(), envelope, replier)
-            });
-            let send_result = msg_tx.send(msg_received).await;
-            if let Err(err) = send_result {
-                error!("error sending message to mpsc queue: {err}");
-            }
         }
     }
+
+    /// Returns the path the client is bound to.
+    fn client_path(peer_identity: Option<Box<dyn Any>>) -> Result<Arc<BlockPath>> {
+        let peer_identity =
+            peer_identity.ok_or_else(|| bterr!("connection did not contain a peer identity"))?;
+        let client_certs = peer_identity
+            .downcast::<Vec<rustls::Certificate>>()
+            .map_err(|_| bterr!("failed to downcast peer_identity to certificate chain"))?;
+        let first = client_certs
+            .first()
+            .ok_or_else(|| bterr!("no certificates were presented by the client"))?;
+        let (writecap, ..) = Writecap::from_cert_chain(first, &client_certs[1..])?;
+        Ok(Arc::new(writecap.bind_path()))
+    }
 }
 
-impl<T> Drop for QuicReceiver<T> {
+impl Drop for QuicReceiver {
     fn drop(&mut self) {
         // This result will be a failure if the tasks have already returned, which is not a
         // problem.
@@ -519,27 +495,28 @@ impl<T> Drop for QuicReceiver<T> {
     }
 }
 
-impl<T: CallRx> Stream for QuicReceiver<T> {
-    type Item = Result<MsgReceived<T>>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        self.stream.poll_next_unpin(cx)
+impl Receiver for QuicReceiver {
+    fn addr(&self) -> &Arc<BlockAddr> {
+        &self.recv_addr
     }
-}
 
-impl<T: CallRx> Receiver<T> for QuicReceiver<T> {
-    fn addr(&self) -> &BlockAddr {
-        &self.recv_addr
+    type Transmitter = QuicTransmitter;
+    type TransmitterFut<'a> = Pin<Box<dyn 'a + Future<Output = Result<QuicTransmitter>> + Send>>;
+
+    fn transmitter(&self, addr: Arc<BlockAddr>) -> Self::TransmitterFut<'_> {
+        Box::pin(async {
+            QuicTransmitter::from_endpoint(self.endpoint.clone(), addr, self.resolver.clone()).await
+        })
     }
 }
 
-struct QuicSender {
+struct QuicTransmitter {
     addr: Arc<BlockAddr>,
     sink: FramedWrite<SendStream, MsgEncoder>,
     recv_stream: Mutex<RecvStream>,
 }
 
-impl QuicSender {
+impl QuicTransmitter {
     async fn from_endpoint(
         endpoint: Endpoint,
         addr: Arc<BlockAddr>,
@@ -567,15 +544,15 @@ impl QuicSender {
 /// https://github.com/rust-lang/rust/issues/63063
 /// feature lands the future types in this implementation should be rewritten to
 /// use it.
-impl Transmitter for QuicSender {
-    fn addr(&self) -> &BlockAddr {
+impl Transmitter for QuicTransmitter {
+    fn addr(&self) -> &Arc<BlockAddr> {
         &self.addr
     }
 
-    type SendFut<'a, T> = SendFut<'a, FramedWrite<SendStream, MsgEncoder>, Envelope<T>>
-        where T: 'a + Serialize + Send;
+    type SendFut<'s, T> = SendFut<'s, FramedWrite<SendStream, MsgEncoder>, Envelope<T>>
+        where T: 's + Serialize + Send;
 
-    fn send<'a, T: 'a + Serialize + Send>(&'a mut self, msg: T) -> Self::SendFut<'a, T> {
+    fn send<'de, T: 'de + SendMsg<'de>>(&'de mut self, msg: T) -> Self::SendFut<'de, T> {
         self.sink.send(Envelope::send(msg))
     }
 
@@ -588,13 +565,15 @@ impl Transmitter for QuicSender {
         })
     }
 
-    type CallFut<'a, T> = Pin<Box<dyn 'a + Future<Output = Result<T::Reply>> + Send>>
+    type CallFut<'s, 'de, T> = Pin<Box<dyn 's + Future<Output = Result<T::Reply>> + Send>>
     where
-        T: 'a + CallTx;
+        T: 's + CallMsg<'de>,
+        T::Reply: 's;
 
-    fn call<'a, T>(&'a mut self, msg: T) -> Self::CallFut<'a, T>
+    fn call<'s, 'de, T>(&'s mut self, msg: T) -> Self::CallFut<'s, 'de, T>
     where
-        T: 'a + CallTx,
+        T: 's + CallMsg<'de>,
+        T::Reply: 's,
     {
         Box::pin(async move {
             self.sink.send(Envelope::call(msg)).await?;
@@ -607,14 +586,3 @@ impl Transmitter for QuicSender {
         })
     }
 }
-
-/// This is an identify function which allows you to specify a type parameter for the output
-/// of a future.
-/// TODO: This was needed to work around a failure in type inference for types with higher-rank
-/// lifetimes. Once this issue is resolved this can be removed:
-/// https://github.com/rust-lang/rust/issues/102211
-pub fn assert_send<'a, T>(
-    fut: impl 'a + Future<Output = T> + Send,
-) -> impl 'a + Future<Output = T> + Send {
-    fut
-}

+ 182 - 232
crates/btmsg/tests/tests.rs

@@ -1,23 +1,25 @@
 use btmsg::*;
 
 use btlib::{
-    crypto::{ConcreteCreds, Creds},
-    Epoch, Principal, Principaled, Result,
+    crypto::{ConcreteCreds, Creds, CredsPriv},
+    BlockPath, Epoch, Principal, Principaled,
 };
+use core::future::Future;
 use ctor::ctor;
-use futures::stream::StreamExt;
 use lazy_static::lazy_static;
 use serde::{Deserialize, Serialize};
 use std::{
     net::{IpAddr, Ipv6Addr},
-    sync::Arc,
+    sync::{Arc, Mutex},
     time::Duration,
 };
-use tokio::sync::mpsc;
+use tokio::sync::mpsc::{self, Sender};
 
 #[ctor]
 fn setup_logging() {
-    env_logger::init();
+    use env_logger::Env;
+    let env = Env::default().default_filter_or("ERROR");
+    env_logger::init_from_env(env);
 }
 
 lazy_static! {
@@ -38,304 +40,252 @@ lazy_static! {
     static ref ROOT_PRINCIPAL: Principal = ROOT_CREDS.principal();
 }
 
-#[derive(Serialize, Deserialize)]
-enum MsgError {
-    Unknown,
+#[derive(Debug, Serialize, Deserialize)]
+enum Reply {
+    Success,
+    Fail,
+    ReadReply { offset: u64, buf: Vec<u8> },
 }
 
-#[derive(Deserialize)]
-enum BodyOwned {
+#[derive(Serialize, Deserialize)]
+enum Msg<'a> {
     Ping,
     Success,
-    Fail(MsgError),
+    Fail,
     Read { offset: u64, size: u64 },
-    Write { offset: u64, buf: Vec<u8> },
+    Write { offset: u64, buf: &'a [u8] },
 }
 
-impl CallRx for BodyOwned {
-    type Reply<'a> = BodyRef<'a>;
+impl<'a> CallMsg<'a> for Msg<'a> {
+    type Reply = Reply;
 }
 
-impl SendRx for BodyOwned {}
+impl<'a> SendMsg<'a> for Msg<'a> {}
 
-#[derive(Serialize)]
-enum BodyRef<'a> {
-    Ping,
-    Success,
-    Fail(MsgError),
-    Read { offset: u64, size: u64 },
-    Write { offset: u64, buf: &'a [u8] },
+trait TestFunc<S: 'static + Send, Fut: Send + Future>:
+    Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut
+{
 }
 
-impl<'a> CallTx for BodyRef<'a> {
-    type Reply = BodyOwned;
+impl<
+        S: 'static + Send,
+        Fut: Send + Future,
+        T: Send + Sync + Fn(MsgReceived<Msg<'_>>, Sender<S>) -> Fut,
+    > TestFunc<S, Fut> for T
+{
 }
 
-impl<'a> SendTx for BodyRef<'a> {}
-
-struct TestCase;
+struct Delegate<S, Fut> {
+    func: Arc<dyn TestFunc<S, Fut>>,
+    sender: Sender<S>,
+}
 
-impl TestCase {
-    fn new() -> TestCase {
-        Self
+impl<S, Fut> Clone for Delegate<S, Fut> {
+    fn clone(&self) -> Self {
+        Self {
+            func: self.func.clone(),
+            sender: self.sender.clone(),
+        }
     }
+}
 
-    fn new_process_router(&self) -> (impl Router, Arc<BlockAddr>) {
-        let ip_addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
-        let mut creds = ConcreteCreds::generate().unwrap();
-        let writecap = NODE_CREDS
-            .issue_writecap(
-                creds.principal(),
-                vec![],
-                Epoch::now() + Duration::from_secs(3600),
-            )
-            .unwrap();
-        let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
-        creds.set_writecap(writecap);
-        (router(ip_addr, Arc::new(creds)).unwrap(), addr)
+impl<S: 'static + Send, Fut: Send + Future> Delegate<S, Fut> {
+    fn new<F: 'static + TestFunc<S, Fut>>(sender: Sender<S>, func: F) -> Self {
+        Self {
+            func: Arc::new(func),
+            sender,
+        }
     }
+}
 
-    /// Returns a ([Sender], [Receiver]) pair for a process identified by the given integer.
-    async fn new_process(&self) -> (impl Transmitter, impl Receiver<BodyOwned>) {
-        let (router, addr) = self.new_process_router();
-        let receiver = router.receiver::<BodyOwned>().await.unwrap();
-        let sender = router.transmitter(addr).await.unwrap();
-        (sender, receiver)
+impl<S: 'static + Send, Fut: Send + Future> MsgCallback for Delegate<S, Fut> {
+    type Arg<'de> = Msg<'de> where Self: 'de;
+    type Return = Fut::Output;
+    type CallFut<'s> = Fut where Fut: 's;
+    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
+        (self.func)(arg, self.sender.clone())
     }
 }
 
-#[tokio::test]
-async fn message_received_is_message_sent() {
-    let case = TestCase::new();
-    let (mut sender, mut receiver) = case.new_process().await;
+fn proc_creds() -> impl Creds {
+    let mut creds = ConcreteCreds::generate().unwrap();
+    let writecap = NODE_CREDS
+        .issue_writecap(
+            creds.principal(),
+            vec![],
+            Epoch::now() + Duration::from_secs(3600),
+        )
+        .unwrap();
+    creds.set_writecap(writecap);
+    creds
+}
 
-    sender.send(BodyRef::Ping).await.unwrap();
-    let actual = receiver.next().await.unwrap().unwrap();
+fn proc_rx<F: 'static + MsgCallback>(callback: F) -> (impl Receiver, Arc<BlockAddr>) {
+    let ip_addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
+    let creds = proc_creds();
+    let writecap = creds.writecap().unwrap();
+    let addr = Arc::new(BlockAddr::new(ip_addr, Arc::new(writecap.bind_path())));
+    (receiver(ip_addr, Arc::new(creds), callback).unwrap(), addr)
+}
 
-    let matched = if let BodyOwned::Ping = actual.body() {
-        true
-    } else {
-        false
-    };
-    assert!(matched);
+async fn proc_tx_rx<F: 'static + MsgCallback>(func: F) -> (impl Transmitter, impl Receiver) {
+    let (receiver, addr) = proc_rx(func);
+    let sender = receiver.transmitter(addr).await.unwrap();
+    (sender, receiver)
 }
 
-#[tokio::test]
-async fn message_received_from_path_is_correct() {
-    let case = TestCase::new();
-    let (mut sender, mut receiver) = case.new_process().await;
+async fn file_server() -> (impl Transmitter, impl Receiver) {
+    let (sender, _) = mpsc::channel::<()>(1);
+    let file = Arc::new(Mutex::new([0, 1, 2, 3, 4, 5, 6, 7]));
+    proc_tx_rx(Delegate::new(
+        sender,
+        move |mut received: MsgReceived<Msg<'_>>, _| {
+            let mut guard = file.lock().unwrap();
+            let reply_body = match received.body() {
+                Msg::Read { offset, size } => {
+                    let offset: usize = (*offset).try_into().unwrap();
+                    let size: usize = (*size).try_into().unwrap();
+                    let end: usize = offset + size;
+                    let mut buf = Vec::with_capacity(end - offset);
+                    buf.extend_from_slice(&guard[offset..end]);
+                    Reply::ReadReply {
+                        offset: offset as u64,
+                        buf,
+                    }
+                }
+                Msg::Write { offset, ref buf } => {
+                    let offset: usize = (*offset).try_into().unwrap();
+                    let end: usize = offset + buf.len();
+                    (&mut guard[offset..end]).copy_from_slice(buf);
+                    Reply::Success
+                }
+                _ => Reply::Fail,
+            };
+            let replier = received.take_replier().unwrap();
+            async move { replier.reply(reply_body).await }
+        },
+    ))
+    .await
+}
 
-    sender.send(BodyRef::Ping).await.unwrap();
-    let actual = receiver.next().await.unwrap().unwrap();
+async fn timeout<F: Future>(future: F) -> F::Output {
+    tokio::time::timeout(Duration::from_millis(1000), future)
+        .await
+        .unwrap()
+}
 
-    assert_eq!(receiver.addr().path(), actual.from().as_ref());
+macro_rules! recv {
+    ($rx:expr) => {
+        timeout($rx.recv()).await.unwrap()
+    };
 }
 
 #[tokio::test]
-async fn ping_pong() {
-    let case = TestCase::new();
-    let (mut sender_one, mut receiver_one) = case.new_process().await;
-    let (mut sender_two, mut receiver_two) = case.new_process().await;
+async fn message_received_is_message_sent() {
+    let (sender, mut passed) = mpsc::channel(1);
+    let (mut sender, _receiver) = proc_tx_rx(Delegate::new(
+        sender,
+        |msg: MsgReceived<Msg<'_>>, sender: Sender<bool>| {
+            let passed = if let Msg::Ping = msg.body() {
+                true
+            } else {
+                false
+            };
+            let sender = sender.clone();
+            async move {
+                sender.send(passed).await.unwrap();
+            }
+        },
+    ))
+    .await;
 
-    tokio::spawn(async move {
-        let received = receiver_one.next().await.unwrap().unwrap();
-        let reply_body = if let BodyOwned::Ping = received.body() {
-            BodyRef::Success
-        } else {
-            BodyRef::Fail(MsgError::Unknown)
-        };
-        let fut = assert_send::<'_, Result<()>>(sender_two.send(reply_body));
-        fut.await.unwrap();
-        sender_two.finish().await.unwrap();
-    });
+    sender.send(Msg::Ping).await.unwrap();
 
-    sender_one.send(BodyRef::Ping).await.unwrap();
-    let reply = receiver_two.next().await.unwrap().unwrap();
-    let matched = if let BodyOwned::Success = reply.body() {
-        true
-    } else {
-        false
-    };
-    assert!(matched);
-    assert_eq!(receiver_two.addr().path(), reply.from().as_ref());
+    assert!(recv!(passed));
 }
 
 #[tokio::test]
-async fn read_write() {
-    let case = TestCase::new();
-    let (mut sender_one, mut receiver_one) = case.new_process().await;
-    let (mut sender_two, mut receiver_two) = case.new_process().await;
-
-    let handle = tokio::spawn(async move {
-        let data: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
-        let received = receiver_one.next().await.unwrap().unwrap();
-        let reply_body = if let BodyOwned::Read { offset, size } = received.body() {
-            let offset: usize = (*offset).try_into().unwrap();
-            let size: usize = (*size).try_into().unwrap();
-            let end: usize = offset + size;
-            BodyRef::Write {
-                offset: offset as u64,
-                buf: &data[offset..end],
+async fn message_received_from_path_is_correct() {
+    let (sender, mut path) = mpsc::channel(1);
+    let (mut sender, receiver) = proc_tx_rx(Delegate::new(
+        sender,
+        |msg: MsgReceived<Msg<'_>>, sender: Sender<Arc<BlockPath>>| {
+            let path = msg.from().clone();
+            let sender = sender.clone();
+            async move {
+                sender.send(path).await.unwrap();
             }
-        } else {
-            BodyRef::Fail(MsgError::Unknown)
-        };
-        let fut = assert_send::<'_, Result<()>>(sender_two.send(reply_body));
-        fut.await.unwrap();
-        sender_two.finish().await.unwrap();
-    });
+        },
+    ))
+    .await;
 
-    sender_one
-        .send(BodyRef::Read { offset: 2, size: 2 })
-        .await
-        .unwrap();
-    handle.await.unwrap();
-    let reply = receiver_two.next().await.unwrap().unwrap();
-    if let BodyOwned::Write { offset, buf } = reply.body() {
-        assert_eq!(2, *offset);
-        assert_eq!([2, 3].as_slice(), buf.as_slice());
-    } else {
-        panic!("reply was not the right type");
-    };
-}
+    sender.send(Msg::Ping).await.unwrap();
 
-async fn file_server<T: Receiver<BodyOwned> + Unpin>(
-    mut receiver: T,
-    mut stop_rx: mpsc::Receiver<()>,
-) {
-    let mut file: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
-    loop {
-        let mut received = tokio::select! {
-            Some(..) = stop_rx.recv() => return,
-            Some(received) = receiver.next() => received.unwrap(),
-        };
-        let reply_body = match received.body() {
-            BodyOwned::Read { offset, size } => {
-                let offset: usize = (*offset).try_into().unwrap();
-                let size: usize = (*size).try_into().unwrap();
-                let end: usize = offset + size;
-                BodyRef::Write {
-                    offset: offset as u64,
-                    buf: &file[offset..end],
-                }
-            }
-            BodyOwned::Write { offset, ref buf } => {
-                let offset: usize = (*offset).try_into().unwrap();
-                (&mut file[offset..buf.len()]).copy_from_slice(buf);
-                BodyRef::Success
-            }
-            _ => BodyRef::Fail(MsgError::Unknown),
-        };
-        received.reply(reply_body).await.unwrap();
-    }
+    assert_eq!(receiver.addr().path(), recv!(path).as_ref());
 }
 
 #[tokio::test]
 async fn reply_to_read() {
-    let case = TestCase::new();
-    let (mut sender, receiver) = case.new_process().await;
-    let (stop_tx, stop_rx) = mpsc::channel::<()>(1);
-
-    let handle = tokio::spawn(file_server(receiver, stop_rx));
-
+    let (mut sender, _receiver) = file_server().await;
     let reply = sender
-        .call(BodyRef::Read { offset: 2, size: 2 })
+        .call(Msg::Read { offset: 2, size: 2 })
         .await
         .unwrap();
-    if let BodyOwned::Write { offset, buf } = reply {
+    if let Reply::ReadReply { offset, buf } = reply {
         assert_eq!(2, offset);
         assert_eq!([2, 3].as_slice(), buf.as_slice());
     } else {
         panic!("reply was not the right type");
     };
-
-    stop_tx.send(()).await.unwrap();
-    handle.await.unwrap();
 }
 
 #[tokio::test]
 async fn call_twice() {
-    let case = TestCase::new();
-    let (mut sender, receiver) = case.new_process().await;
-    let (stop_tx, stop_rx) = mpsc::channel::<()>(1);
-
-    let handle = tokio::spawn(file_server(receiver, stop_rx));
+    let (mut sender, _receiver) = file_server().await;
 
     let reply = sender
-        .call(BodyRef::Read { offset: 2, size: 2 })
+        .call(Msg::Write {
+            offset: 1,
+            buf: &[1, 1],
+        })
         .await
         .unwrap();
-    if let BodyOwned::Write { offset, buf } = reply {
-        assert_eq!(2, offset);
-        assert_eq!([2, 3].as_slice(), buf.as_slice());
+    if let Reply::Success = reply {
+        ()
     } else {
         panic!("reply was not the right type");
     };
     let reply = sender
-        .call(BodyRef::Read { offset: 3, size: 5 })
+        .call(Msg::Read { offset: 1, size: 2 })
         .await
         .unwrap();
-    if let BodyOwned::Write { offset, buf } = reply {
-        assert_eq!(3, offset);
-        assert_eq!([3, 4, 5, 6, 7].as_slice(), buf.as_slice());
+    if let Reply::ReadReply { offset, buf } = reply {
+        assert_eq!(1, offset);
+        assert_eq!([1, 1].as_slice(), buf.as_slice());
     } else {
         panic!("second reply was not the right type");
     }
-
-    stop_tx.send(()).await.unwrap();
-    handle.await.unwrap();
 }
 
 #[tokio::test]
-async fn replies_sent_out_of_order() {
-    let case = TestCase::new();
-    let (sender_one, mut receiver_one) = case.new_process().await;
-    let (router, ..) = case.new_process_router();
-    let sender_two = router
-        .transmitter(Arc::new(sender_one.addr().clone()))
+async fn separate_transmitter() {
+    let (_senderx, receiver) = file_server().await;
+    let creds = proc_creds();
+    let mut transmitter = transmitter(receiver.addr().clone(), Arc::new(creds))
         .await
         .unwrap();
 
-    let handle = tokio::spawn(async move {
-        const EMPTY_SLICE: &[u8] = &[];
-        fn reply(body: &BodyOwned) -> BodyRef<'static> {
-            match body {
-                BodyOwned::Write { offset, .. } => BodyRef::Write {
-                    offset: *offset,
-                    buf: EMPTY_SLICE,
-                },
-                _ => panic!("message was the wrong variant"),
-            }
-        }
-        let mut received_one = receiver_one.next().await.unwrap().unwrap();
-        let mut received_two = receiver_one.next().await.unwrap().unwrap();
-        received_two
-            .reply(reply(received_two.body()))
-            .await
-            .unwrap();
-        received_one
-            .reply(reply(received_one.body()))
-            .await
-            .unwrap();
-    });
-
-    async fn client(num: u64, mut tx: impl Transmitter) {
-        let fut = assert_send::<'_, Result<BodyOwned>>(tx.call(BodyRef::Write {
-            offset: num,
-            buf: [].as_slice(),
-        }));
-        let reply = fut.await.unwrap();
-        if let BodyOwned::Write { offset, .. } = reply {
-            assert_eq!(num, offset);
-        } else {
-            panic!("reply was the wrong variant");
-        }
-    }
-
-    let handle_one = tokio::spawn(client(1, sender_one));
-    let handle_two = tokio::spawn(client(2, sender_two));
-
-    handle.await.unwrap();
-    handle_one.await.unwrap();
-    handle_two.await.unwrap();
+    let reply = transmitter
+        .call(Msg::Write {
+            offset: 5,
+            buf: &[7, 7, 7],
+        })
+        .await
+        .unwrap();
+    let matched = if let Reply::Success = reply {
+        true
+    } else {
+        false
+    };
+    assert!(matched);
 }

+ 86 - 45
crates/btserde/src/de.rs

@@ -1,22 +1,24 @@
-use super::error::{Error, Result};
+use crate::{
+    error::{Error, Result},
+    reader::{ReadAdapter, Reader, SliceAdapter},
+};
 use serde::de::{self, Deserialize, DeserializeOwned, DeserializeSeed, IntoDeserializer, Visitor};
 use std::convert::TryFrom;
 use std::io::Read;
 use std::str;
 
-// This lint is disabled because deserializing from a `&[u8]` is handled by `read_from`.
-pub fn from_vec<T: DeserializeOwned>(vec: &Vec<u8>) -> Result<T> {
-    let mut slice = vec.as_slice();
-    read_from(&mut slice)
+pub fn from_vec<'de, T: Deserialize<'de>>(vec: &'de Vec<u8>) -> Result<T> {
+    from_slice(vec.as_slice())
 }
 
-pub fn read_from<T: for<'de> Deserialize<'de>, R: Read>(read: &mut R) -> Result<T> {
-    let mut de = Deserializer::new(read);
+pub fn read_from<T: DeserializeOwned, R: Read>(read: &mut R) -> Result<T> {
+    let mut de = Deserializer::new(ReadAdapter::new(read));
     Deserialize::deserialize(&mut de)
 }
 
-pub struct Deserializer<'de, T: Read + ?Sized> {
-    input: &'de mut T,
+pub fn from_slice<'de, T: Deserialize<'de>>(slice: &'de [u8]) -> Result<T> {
+    let mut de = Deserializer::new(SliceAdapter::new(slice));
+    Deserialize::deserialize(&mut de)
 }
 
 fn try_from<TSource, TDest: TryFrom<TSource>>(value: TSource) -> Result<TDest> {
@@ -37,7 +39,7 @@ fn num_bytes_for_char(first_byte: u8) -> Result<usize> {
             return Ok(bit);
         }
     }
-    Err(Error::InvalidUtf8Char)
+    Err(Error::InvalidUtf8)
 }
 
 /// Returns the unicode code point of the character that is encoded in UTF-8 in the given buffer
@@ -72,23 +74,30 @@ fn u32_from_utf8(buf: &[u8]) -> Result<u32> {
         4 => Some(four_bytes(buf)),
         _ => None,
     };
-    code_point.ok_or(Error::InvalidUtf8Char)
+    code_point.ok_or(Error::InvalidUtf8)
 }
 
 fn char_from_utf8(buf: &[u8]) -> Result<char> {
     let result = u32_from_utf8(buf);
     let option = char::from_u32(result?);
-    option.ok_or(Error::InvalidUtf8Char)
+    option.ok_or(Error::InvalidUtf8)
+}
+
+pub struct Deserializer<T> {
+    input: T,
 }
 
-impl<'de, T: Read + ?Sized> Deserializer<'de, T> {
-    pub fn new(input: &'de mut T) -> Self {
-        Deserializer { input }
+impl<'de, T: Reader<'de>> Deserializer<T> {
+    pub fn new(input: T) -> Self {
+        Self { input }
     }
 
     fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
-        self.input.read_exact(buf).map_err(Error::Io)?;
-        Ok(())
+        self.input.read_exact(buf)
+    }
+
+    fn borrow_bytes(&mut self, len: usize) -> Result<&'de [u8]> {
+        self.input.borrow_bytes(len)
     }
 
     fn read_array<const N: usize>(&mut self) -> Result<[u8; N]> {
@@ -159,20 +168,28 @@ impl<'de, T: Read + ?Sized> Deserializer<'de, T> {
         Ok(vec)
     }
 
+    fn read_bytes(&mut self) -> Result<&'de [u8]> {
+        let len = try_from(self.read_u32()?)?;
+        self.borrow_bytes(len)
+    }
+
     fn read_string(&mut self) -> Result<String> {
         let vec = self.read_vec()?;
         let value = String::from_utf8(vec).map_err(|_| Error::TypeConversion)?;
         Ok(value)
     }
+
+    fn read_str(&mut self) -> Result<&'de str> {
+        let bytes = self.read_bytes()?;
+        std::str::from_utf8(bytes).map_err(|_| Error::InvalidUtf8)
+    }
 }
 
-impl<'de, 'a, T: Read> de::Deserializer<'de> for &'a mut Deserializer<'de, T> {
+impl<'de, 'a, T: Reader<'de>> de::Deserializer<'de> for &'a mut Deserializer<T> {
     type Error = Error;
 
     fn deserialize_any<V: Visitor<'de>>(self, _visitor: V) -> Result<V::Value> {
-        Err(Error::Message(
-            "deserialize_any is not supported".to_string(),
-        ))
+        Err(Error::NotSupported("deserialize_any is not supported"))
     }
 
     fn deserialize_bool<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
@@ -246,17 +263,17 @@ impl<'de, 'a, T: Read> de::Deserializer<'de> for &'a mut Deserializer<'de, T> {
 
     fn deserialize_char<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
         let byte = self.read_u8()?;
-        let buf_len = num_bytes_for_char(byte);
-        let mut buf = vec![0; buf_len?];
+        let char_len = num_bytes_for_char(byte)?;
+        let mut buf = [0u8; 4];
         buf[0] = byte;
-        self.read_exact(&mut buf[1..])?;
-        let result = char_from_utf8(&buf);
+        self.read_exact(&mut buf[1..char_len])?;
+        let result = char_from_utf8(&buf[..char_len]);
         visitor.visit_char(result?)
     }
 
     fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
-        let value = self.read_string()?;
-        visitor.visit_str(value.as_str())
+        let value = self.read_str()?;
+        visitor.visit_borrowed_str(value)
     }
 
     fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
@@ -265,8 +282,8 @@ impl<'de, 'a, T: Read> de::Deserializer<'de> for &'a mut Deserializer<'de, T> {
     }
 
     fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
-        let value = self.read_vec()?;
-        visitor.visit_bytes(value.as_slice())
+        let value = self.read_bytes()?;
+        visitor.visit_borrowed_bytes(value)
     }
 
     fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
@@ -366,12 +383,12 @@ impl<'de, 'a, T: Read> de::Deserializer<'de> for &'a mut Deserializer<'de, T> {
     }
 }
 
-struct SeqAccess<'a, 'de, T: Read> {
+struct SeqAccess<'a, T> {
     elements_left: u32,
-    deserializer: &'a mut Deserializer<'de, T>,
+    deserializer: &'a mut Deserializer<T>,
 }
 
-impl<'a, 'de, T: Read> de::SeqAccess<'de> for SeqAccess<'a, 'de, T> {
+impl<'a, 'de, T: Reader<'de>> de::SeqAccess<'de> for SeqAccess<'a, T> {
     type Error = Error;
 
     fn next_element_seed<S: DeserializeSeed<'de>>(&mut self, seed: S) -> Result<Option<S::Value>> {
@@ -384,7 +401,7 @@ impl<'a, 'de, T: Read> de::SeqAccess<'de> for SeqAccess<'a, 'de, T> {
     }
 }
 
-impl<'a, 'de, T: Read> de::MapAccess<'de> for SeqAccess<'a, 'de, T> {
+impl<'a, 'de, T: Reader<'de>> de::MapAccess<'de> for SeqAccess<'a, T> {
     type Error = Error;
 
     fn next_key_seed<S: DeserializeSeed<'de>>(&mut self, seed: S) -> Result<Option<S::Value>> {
@@ -401,7 +418,7 @@ impl<'a, 'de, T: Read> de::MapAccess<'de> for SeqAccess<'a, 'de, T> {
     }
 }
 
-impl<'de, T: Read> de::VariantAccess<'de> for &mut Deserializer<'de, T> {
+impl<'de, T: Reader<'de>> de::VariantAccess<'de> for &mut Deserializer<T> {
     type Error = Error;
 
     fn unit_variant(self) -> Result<()> {
@@ -425,7 +442,7 @@ impl<'de, T: Read> de::VariantAccess<'de> for &mut Deserializer<'de, T> {
     }
 }
 
-impl<'de, T: Read> de::EnumAccess<'de> for &mut Deserializer<'de, T> {
+impl<'de, T: Reader<'de>> de::EnumAccess<'de> for &mut Deserializer<T> {
     type Error = Error;
     type Variant = Self;
 
@@ -439,19 +456,12 @@ impl<'de, T: Read> de::EnumAccess<'de> for &mut Deserializer<'de, T> {
 
 #[cfg(test)]
 mod test {
-    use super::{from_vec, num_bytes_for_char, Deserializer, Result};
+    use crate::from_slice;
+
+    use super::{from_vec, num_bytes_for_char, Result};
     use serde::Deserialize;
     use std::collections::HashMap;
 
-    #[test]
-    fn new() -> Result<()> {
-        let vec: Vec<u8> = vec![0xA1, 0x42, 0x71, 0xAC];
-        let mut slice = vec.as_slice();
-        let de = Deserializer::new(&mut slice);
-        assert_eq!(&vec.as_slice(), de.input);
-        Ok(())
-    }
-
     #[test]
     fn test_num_bytes_for_char() -> Result<()> {
         fn test_case(c: char) -> Result<()> {
@@ -603,6 +613,28 @@ mod test {
         Ok(())
     }
 
+    /// Returns the bytes that would be returned if the given data were serialized.
+    macro_rules! input {
+        ($expected:expr) => {{
+            const LEN: [u8; 4] = ($expected.len() as u32).to_le_bytes();
+            const INPUT_LEN: usize = LEN.len() + EXPECTED.len();
+            let mut input = [0u8; INPUT_LEN];
+            (&mut input[..LEN.len()]).copy_from_slice(&LEN);
+            (&mut input[LEN.len()..]).copy_from_slice(&EXPECTED);
+            input
+        }};
+    }
+
+    #[test]
+    fn deserialize_str() -> Result<()> {
+        const EXPECTED: &[u8] = b"Without nature's limits, humanity's true nature was revealed.";
+        let input = input!(EXPECTED);
+        let actual: &str = from_slice(input.as_slice())?;
+        let expected = std::str::from_utf8(EXPECTED).unwrap();
+        assert_eq!(expected, actual);
+        Ok(())
+    }
+
     #[test]
     fn deserialize_string() -> Result<()> {
         let vec: Vec<u8> = vec![
@@ -614,6 +646,15 @@ mod test {
         Ok(())
     }
 
+    #[test]
+    fn deserialize_bytes() -> Result<()> {
+        const EXPECTED: &[u8] = b"I have altered the API, pray I don't alter it any further.";
+        let input = input!(EXPECTED);
+        let actual: &[u8] = from_slice(input.as_slice())?;
+        assert_eq!(EXPECTED, actual);
+        Ok(())
+    }
+
     #[test]
     fn deserialize_byte_buf() -> Result<()> {
         let vec: Vec<u8> = vec![0x04, 0x00, 0x00, 0x00, 0x42, 0x91, 0xBE, 0xEF];

+ 6 - 8
crates/btserde/src/error.rs

@@ -16,7 +16,7 @@ pub enum Error {
     TooManyVariants(u32),
     TypeConversion,
     NotSupported(&'static str),
-    InvalidUtf8Char,
+    InvalidUtf8,
     Format(std::fmt::Error),
     Custom(Box<dyn std::error::Error + Send + Sync>),
 }
@@ -30,19 +30,17 @@ impl Display for Error {
             Error::Io(io_error) => io_error.fmt(formatter),
             Error::Eof => formatter.write_str("unexpected end of input"),
             Error::UnknownLength => formatter.write_str("sequence had an unknown length"),
-            Error::SequenceTooLong(length) => formatter.write_fmt(format_args!(
-                "sequence was longer than 2**32 - 1: {}",
-                length
-            )),
+            Error::SequenceTooLong(length) => {
+                formatter.write_fmt(format_args!("sequence was longer than 2**32 - 1: {length}",))
+            }
             Error::TooManyVariants(length) => formatter.write_fmt(format_args!(
-                "too many variants to be serialized, the limit is 2**16: {}",
-                length
+                "too many variants to be serialized, the limit is 2**16: {length}",
             )),
             Error::TypeConversion => formatter.write_str("type conversion failed"),
             Error::NotSupported(message) => {
                 formatter.write_fmt(format_args!("Operation is not supported: {message}"))
             }
-            Error::InvalidUtf8Char => formatter.write_str("Invalid UTF-8 character encountered."),
+            Error::InvalidUtf8 => formatter.write_str("Invalid UTF-8 character encountered."),
             Error::Format(fmt_error) => fmt_error.fmt(formatter),
             Error::Custom(err) => err.fmt(formatter),
         }

+ 4 - 2
crates/btserde/src/lib.rs

@@ -1,6 +1,8 @@
+//! This crate defines a compact binary serialization format for use in the Block Tree system.
+
 mod de;
-/// This crate defines a compact binary serialization format for use in the Block Tree system.
 mod error;
+mod reader;
 mod ser;
 
 #[cfg(test)]
@@ -10,4 +12,4 @@ pub use error::{Error, Result};
 
 pub use ser::{to_vec, write_to, Serializer};
 
-pub use de::{from_vec, read_from, Deserializer};
+pub use de::{from_slice, from_vec, read_from, Deserializer};

+ 91 - 0
crates/btserde/src/reader.rs

@@ -0,0 +1,91 @@
+//! This module contains the [Reader] trait which enables zero-copy deserialization over
+//! different input types.
+use crate::{error::MapError, Error, Result};
+
+use std::io::Read;
+
+pub trait Reader<'de> {
+    /// Reads exactly the enough bytes to fill the given buffer or returns an error.
+    fn read_exact(&mut self, buf: &mut [u8]) -> Result<()>;
+    /// Returns true if the `borrow_bytes` method is supported by this instance.
+    fn can_borrow() -> bool;
+    /// Borrows the given number of bytes from this [Reader] starting at the current position.
+    fn borrow_bytes(&mut self, len: usize) -> Result<&'de [u8]>;
+}
+
+impl<'de, R: Reader<'de>> Reader<'de> for &mut R {
+    fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
+        (*self).read_exact(buf)
+    }
+
+    fn can_borrow() -> bool {
+        R::can_borrow()
+    }
+
+    fn borrow_bytes(&mut self, len: usize) -> Result<&'de [u8]> {
+        (*self).borrow_bytes(len)
+    }
+}
+
+/// An adapter over an implementation of the standard library [Read] trait.
+pub struct ReadAdapter<R>(R);
+
+impl<R> ReadAdapter<R> {
+    pub fn new(read: R) -> Self {
+        Self(read)
+    }
+}
+
+impl<'de, R: Read> Reader<'de> for ReadAdapter<R> {
+    fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
+        self.0.read_exact(buf).map_error()
+    }
+
+    fn can_borrow() -> bool {
+        false
+    }
+
+    fn borrow_bytes(&mut self, _len: usize) -> Result<&'de [u8]> {
+        Err(Error::NotSupported(
+            "borrowing from a ReadAdapter is not supported",
+        ))
+    }
+}
+
+/// An adapter for reading out of a slice of bytes.
+pub struct SliceAdapter<'a>(&'a [u8]);
+
+impl<'a> SliceAdapter<'a> {
+    pub fn new(slice: &'a [u8]) -> Self {
+        Self(slice)
+    }
+
+    fn assert_longer_than(&self, buf_len: usize) -> Result<()> {
+        if self.0.len() >= buf_len {
+            Ok(())
+        } else {
+            Err(Error::Eof)
+        }
+    }
+}
+
+impl<'de> Reader<'de> for SliceAdapter<'de> {
+    fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
+        let buf_len = buf.len();
+        self.assert_longer_than(buf_len)?;
+        buf.copy_from_slice(&self.0[..buf_len]);
+        self.0 = &self.0[buf_len..];
+        Ok(())
+    }
+
+    fn can_borrow() -> bool {
+        true
+    }
+
+    fn borrow_bytes(&mut self, len: usize) -> Result<&'de [u8]> {
+        self.assert_longer_than(len)?;
+        let borrow = &self.0[..len];
+        self.0 = &self.0[len..];
+        Ok(borrow)
+    }
+}