Browse Source

* Added a crate for defining the file server inteface.
* Started modifying btmsg to allow multiple calls concurrently.

Matthew Carr 2 years ago
parent
commit
62121ca7c1

+ 21 - 2
Cargo.lock

@@ -199,6 +199,18 @@ dependencies = [
  "alloc-stdlib",
 ]
 
+[[package]]
+name = "btfproto"
+version = "0.1.0"
+dependencies = [
+ "btlib",
+ "btmsg",
+ "log",
+ "paste",
+ "serde",
+ "tokio",
+]
+
 [[package]]
 name = "btfs"
 version = "0.1.0"
@@ -268,6 +280,7 @@ dependencies = [
  "btlib",
  "btserde",
  "bytes",
+ "chrono",
  "ctor",
  "env_logger",
  "futures",
@@ -1234,6 +1247,12 @@ version = "6.4.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee"
 
+[[package]]
+name = "paste"
+version = "1.0.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d01a5bd0424d00070b0098dd17ebca6f961a959dead1dbcbbbc1d1cd8d3deeba"
+
 [[package]]
 name = "peeking_take_while"
 version = "0.1.2"
@@ -1992,9 +2011,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
 
 [[package]]
 name = "tokio"
-version = "1.23.0"
+version = "1.24.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "eab6d665857cc6ca78d6e80303a02cea7a7851e85dfbd77cbdc09bd129f1ef46"
+checksum = "597a12a59981d9e3c38d216785b0c37399f6e415e8d0712047620f189371b0bb"
 dependencies = [
  "autocfg",
  "bytes",

+ 1 - 0
Cargo.toml

@@ -9,6 +9,7 @@ members = [
     "crates/btfuse",
     "crates/swtpm-harness",
     "crates/btfs",
+    "crates/btfproto",
 ]
 
 [profile.bench]

+ 19 - 0
crates/btfproto/Cargo.toml

@@ -0,0 +1,19 @@
+[package]
+name = "btfproto"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[features]
+server = []
+client = []
+default = ["client", "server"]
+
+[dependencies]
+btlib = { path = "../btlib" }
+btmsg = { path = "../btmsg" }
+serde = { version = "^1.0.136", features = ["derive"] }
+paste = "1.0.11"
+log = "0.4.17"
+tokio = { version = "1.24.2", features = ["rt"] }

+ 1 - 0
crates/btfproto/src/client.rs

@@ -0,0 +1 @@
+

+ 7 - 0
crates/btfproto/src/lib.rs

@@ -0,0 +1,7 @@
+#![feature(type_alias_impl_trait)]
+
+#[cfg(feature = "client")]
+mod client;
+mod msg;
+#[cfg(feature = "server")]
+pub mod server;

+ 205 - 0
crates/btfproto/src/msg.rs

@@ -0,0 +1,205 @@
+use btlib::BlockMeta;
+use btmsg::CallMsg;
+use core::time::Duration;
+use serde::{Deserialize, Serialize};
+use std::fmt::Display;
+
+pub type Inode = u64;
+pub type Handle = u64;
+
+#[derive(Serialize, Deserialize)]
+pub enum FsMsg<'a> {
+    #[serde(borrow)]
+    Lookup(Lookup<'a>),
+    #[serde(borrow)]
+    Create(Create<'a>),
+    Open(Open),
+    Read(Read),
+    #[serde(borrow)]
+    Write(Write<'a>),
+    Flush(Flush),
+    #[serde(borrow)]
+    Link(Link<'a>),
+    #[serde(borrow)]
+    Unlink(Unlink<'a>),
+    ReadMeta(ReadMeta),
+    WriteMeta(WriteMeta),
+    Close(Close),
+    Forget(Forget),
+    Lock(Lock),
+    Unlock(Unlock),
+}
+
+#[derive(Serialize, Deserialize)]
+pub enum FsReply<'a> {
+    Ack(()),
+    Lookup(LookupReply),
+    Create(CreateReply),
+    Open(OpenReply),
+    #[serde(borrow)]
+    Read(ReadReply<'a>),
+    Write(WriteReply),
+    Link(LinkReply),
+    ReadMeta(BlockMeta),
+}
+
+impl<'a> CallMsg<'a> for FsMsg<'a> {
+    type Reply<'b> = FsReply<'b>;
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum FsError {
+    Other,
+}
+
+impl Display for FsError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Other => write!(f, "uncategorized error"),
+        }
+    }
+}
+
+impl std::error::Error for FsError {}
+
+#[derive(Serialize, Deserialize)]
+pub struct Entry {
+    //pub attr: stat64,
+    pub attr_flags: u32,
+    pub attr_timeout: Duration,
+    pub entry_timeout: Duration,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct Lookup<'a> {
+    pub parent: Inode,
+    pub name: &'a str,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct LookupReply {
+    pub inode: Inode,
+    pub generation: u64,
+    pub entry: Entry,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct Create<'a> {
+    pub parent: Inode,
+    pub name: &'a str,
+    pub flags: u32,
+    pub mode: u32,
+    pub umask: u32,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct CreateReply {
+    pub handle: Handle,
+    pub entry: Entry,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct Open {
+    pub inode: Inode,
+    pub flags: u32,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct OpenReply {
+    pub handle: Handle,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct Read {
+    pub inode: Inode,
+    pub handle: Handle,
+    pub offset: u64,
+    pub size: u64,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct ReadReply<'a> {
+    pub data: &'a [u8],
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct Write<'a> {
+    pub inode: Inode,
+    pub handle: Handle,
+    pub offset: u64,
+    pub data: &'a [u8],
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct WriteReply {
+    pub written: u64,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct Flush {
+    pub inode: Inode,
+    pub handle: Handle,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct Link<'a> {
+    pub inode: Inode,
+    pub new_parent: Inode,
+    pub name: &'a str,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct LinkReply {
+    pub entry: Entry,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct Unlink<'a> {
+    pub parent: Inode,
+    pub name: &'a str,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct ReadMeta {
+    pub inode: Inode,
+    pub handle: Handle,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct WriteMeta {
+    pub inode: Inode,
+    pub handle: Handle,
+    pub meta: BlockMeta,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct Close {
+    pub inode: Inode,
+    pub handle: Handle,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct Forget {
+    pub inode: Inode,
+    pub count: u64,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct LockDesc {
+    pub offset: u64,
+    pub size: u64,
+    pub exclusive: bool,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct Lock {
+    pub inode: Inode,
+    pub handle: Handle,
+    pub desc: LockDesc,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct Unlock {
+    pub inode: Inode,
+    pub handle: Handle,
+}

+ 85 - 0
crates/btfproto/src/server.rs

@@ -0,0 +1,85 @@
+use crate::msg::*;
+
+use btlib::{crypto::Creds, BlockMeta, BlockPath};
+use btmsg::{receiver, MsgCallback, MsgReceived, Receiver};
+use core::future::Future;
+use std::{net::IpAddr, result::Result, sync::Arc};
+
+pub trait FsProvider {
+    fn lookup(&self, from: &BlockPath, msg: &Lookup) -> Result<LookupReply, FsError>;
+    fn create(&self, from: &BlockPath, msg: &Create) -> Result<CreateReply, FsError>;
+    fn open(&self, from: &BlockPath, msg: &Open) -> Result<OpenReply, FsError>;
+    fn read(&self, from: &BlockPath, msg: &Read) -> Result<ReadReply<'_>, FsError>;
+    fn write(&self, from: &BlockPath, msg: &Write) -> Result<WriteReply, FsError>;
+    fn flush(&self, from: &BlockPath, msg: &Flush) -> Result<(), FsError>;
+    fn link(&self, from: &BlockPath, msg: &Link) -> Result<(), FsError>;
+    fn unlink(&self, from: &BlockPath, msg: &Unlink) -> Result<(), FsError>;
+    fn read_meta(&self, from: &BlockPath, msg: &ReadMeta) -> Result<BlockMeta, FsError>;
+    fn write_meta(&self, from: &BlockPath, msg: &WriteMeta) -> Result<(), FsError>;
+    fn close(&self, from: &BlockPath, msg: &Close) -> Result<(), FsError>;
+    fn forget(&self, from: &BlockPath, msg: &Forget) -> Result<(), FsError>;
+    fn lock(&self, from: &BlockPath, msg: &Lock) -> Result<(), FsError>;
+    fn unlock(&self, from: &BlockPath, msg: &Unlock) -> Result<(), FsError>;
+}
+
+struct ServerCallback<P> {
+    provider: Arc<P>,
+}
+
+impl<P> ServerCallback<P> {
+    fn new(provider: Arc<P>) -> Self {
+        Self { provider }
+    }
+}
+
+impl<P> Clone for ServerCallback<P> {
+    fn clone(&self) -> Self {
+        Self {
+            provider: self.provider.clone(),
+        }
+    }
+}
+
+impl<P: 'static + Send + Sync + FsProvider> MsgCallback for ServerCallback<P> {
+    type Arg<'de> = FsMsg<'de>;
+    type CallFut<'de> = impl 'de + Future<Output = btlib::Result<()>>;
+    fn call<'de>(&'de self, mut arg: MsgReceived<FsMsg<'de>>) -> Self::CallFut<'de> {
+        async move {
+            let provider = &self.provider;
+            let reply = match arg.body() {
+                FsMsg::Lookup(lookup) => FsReply::Lookup(provider.lookup(arg.from(), lookup)?),
+                FsMsg::Create(create) => FsReply::Create(provider.create(arg.from(), create)?),
+                FsMsg::Open(open) => FsReply::Open(provider.open(arg.from(), open)?),
+                FsMsg::Read(read) => FsReply::Read(provider.read(arg.from(), read)?),
+                FsMsg::Write(write) => FsReply::Write(provider.write(arg.from(), write)?),
+                FsMsg::Flush(flush) => FsReply::Ack(provider.flush(arg.from(), flush)?),
+                FsMsg::Link(link) => FsReply::Ack(provider.link(arg.from(), link)?),
+                FsMsg::Unlink(unlink) => FsReply::Ack(provider.unlink(arg.from(), unlink)?),
+                FsMsg::ReadMeta(read_meta) => {
+                    FsReply::ReadMeta(provider.read_meta(arg.from(), read_meta)?)
+                }
+                FsMsg::WriteMeta(write_meta) => {
+                    FsReply::Ack(provider.write_meta(arg.from(), write_meta)?)
+                }
+                FsMsg::Close(close) => FsReply::Ack(provider.close(arg.from(), close)?),
+                FsMsg::Forget(forget) => FsReply::Ack(provider.forget(arg.from(), forget)?),
+                FsMsg::Lock(lock) => FsReply::Ack(provider.lock(arg.from(), lock)?),
+                FsMsg::Unlock(unlock) => FsReply::Ack(provider.unlock(arg.from(), unlock)?),
+            };
+            let mut replier = arg.take_replier().unwrap();
+            replier.reply(reply).await
+        }
+    }
+}
+
+pub fn new_fs_server<C, P>(
+    ip_addr: IpAddr,
+    creds: Arc<C>,
+    provider: Arc<P>,
+) -> Result<impl Receiver, btlib::Error>
+where
+    C: 'static + Send + Sync + Creds,
+    P: 'static + Send + Sync + FsProvider,
+{
+    receiver(ip_addr, creds, ServerCallback::new(provider))
+}

+ 1 - 0
crates/btmsg/Cargo.toml

@@ -24,3 +24,4 @@ log = "0.4.17"
 env_logger = "0.9.0"
 ctor = { version = "0.1.22" }
 lazy_static = { version = "1.4.0" }
+chrono = "0.4.23"

+ 30 - 32
crates/btmsg/src/callback_framed.rs

@@ -12,6 +12,8 @@ pub struct CallbackFramed<I> {
 
 impl<I> CallbackFramed<I> {
     const INIT_CAPACITY: usize = 4096;
+    /// The number of bytes used to encode the length of each frame.
+    const FRAME_LEN_SZ: usize = std::mem::size_of::<u64>();
 
     pub fn new(inner: I) -> Self {
         Self {
@@ -20,10 +22,18 @@ impl<I> CallbackFramed<I> {
         }
     }
 
-    async fn decode<'de, F: 'de + DeserCallback>(
-        mut slice: &'de [u8],
-        callback: &'de F,
-    ) -> Result<DecodeStatus<F::Return>> {
+    pub fn into_parts(self) -> (I, BytesMut) {
+        (self.io, self.buffer)
+    }
+
+    pub fn from_parts(io: I, mut buffer: BytesMut) -> Self {
+        if buffer.capacity() < Self::INIT_CAPACITY {
+            buffer.reserve(Self::INIT_CAPACITY - buffer.capacity());
+        }
+        Self { io, buffer }
+    }
+
+    async fn decode(mut slice: &[u8]) -> Result<DecodeStatus> {
         let payload_len: u64 = match read_from(&mut slice) {
             Ok(payload_len) => payload_len,
             Err(err) => {
@@ -41,12 +51,7 @@ impl<I> CallbackFramed<I> {
         if slice.len() < payload_len {
             return Ok(DecodeStatus::Reserve(payload_len - slice.len()));
         }
-        let msg: F::Arg<'de> = from_slice(slice)?;
-        let returned = callback.call(msg).await;
-        Ok(DecodeStatus::Some {
-            returned,
-            consumed: std::mem::size_of::<u64>() + payload_len,
-        })
+        Ok(DecodeStatus::Consume( Self::FRAME_LEN_SZ + payload_len ))
     }
 }
 
@@ -60,7 +65,7 @@ macro_rules! attempt {
 }
 
 impl<S: AsyncRead + Unpin> CallbackFramed<S> {
-    pub async fn next<F: DeserCallback>(&mut self, callback: F) -> Option<Result<F::Return>> {
+    pub async fn next<F: DeserCallback>(&mut self, mut callback: F) -> Option<Result<F::Return>> {
         loop {
             if self.buffer.capacity() - self.buffer.len() == 0 {
                 // If there is no space left in the buffer we reserve additional bytes to ensure
@@ -71,14 +76,16 @@ impl<S: AsyncRead + Unpin> CallbackFramed<S> {
             if 0 == read_ct {
                 return None;
             }
-            match attempt!(Self::decode(&self.buffer[..read_ct], &callback).await) {
+            match attempt!(Self::decode(&self.buffer[..read_ct]).await) {
                 DecodeStatus::None => continue,
                 DecodeStatus::Reserve(count) => {
                     self.buffer.reserve(count);
                     continue;
                 }
-                DecodeStatus::Some { returned, consumed } => {
-                    let _ = self.buffer.split_to(consumed);
+                DecodeStatus::Consume(consume) => {
+                    let start = self.buffer.split_to(consume);
+                    let arg: F::Arg<'_> = attempt!(from_slice(&start[Self::FRAME_LEN_SZ..]));
+                    let returned = callback.call(arg).await;
                     return Some(Ok(returned));
                 }
             }
@@ -86,30 +93,21 @@ impl<S: AsyncRead + Unpin> CallbackFramed<S> {
     }
 }
 
-enum DecodeStatus<R> {
+enum DecodeStatus {
     None,
     Reserve(usize),
-    Some { returned: R, consumed: usize },
+    Consume(usize),
 }
 
-pub trait DeserCallback: Clone {
-    type Arg<'de>: Deserialize<'de> + Send
+pub trait DeserCallback {
+    type Arg<'de>: 'de + Deserialize<'de> + Send
     where
         Self: 'de;
     type Return;
-    type CallFut<'s>: Future<Output = Self::Return> + Send
+    type CallFut<'de>: 'de + Future<Output = Self::Return> + Send
     where
-        Self: 's;
-    fn call<'de>(&'de self, arg: Self::Arg<'de>) -> Self::CallFut<'de>;
-}
-
-impl<F: DeserCallback> DeserCallback for &F {
-    type Arg<'de> = F::Arg<'de> where Self: 'de;
-    type Return = F::Return;
-    type CallFut<'f> = F::CallFut<'f> where Self: 'f;
-    fn call<'de>(&'de self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
-        (*self).call(arg)
-    }
+        Self: 'de;
+    fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de>;
 }
 
 #[cfg(test)]
@@ -140,9 +138,9 @@ mod tests {
         impl DeserCallback for TestCb {
             type Arg<'de> = Msg<'de> where Self: 'de;
             type Return = bool;
-            type CallFut<'f> = Ready<bool>;
+            type CallFut<'de> = Ready<Self::Return>;
 
-            fn call<'de>(&'de self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
+            fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
                 futures::future::ready(arg.0 == test_data!())
             }
         }

+ 199 - 145
crates/btmsg/src/lib.rs

@@ -1,15 +1,21 @@
 //! Code which enables sending messages between processes in the blocktree system.
+#![feature(type_alias_impl_trait)]
 
 mod tls;
 use tls::*;
 mod callback_framed;
-use callback_framed::{CallbackFramed, DeserCallback};
+use callback_framed::CallbackFramed;
+pub use callback_framed::DeserCallback;
 
-use btlib::{bterr, crypto::Creds, error::BoxInIoErr, BlockPath, Result, Writecap};
-use btserde::{read_from, write_to};
+use btlib::{bterr, crypto::Creds, BlockPath, Result, Writecap};
+use btserde::write_to;
 use bytes::{BufMut, BytesMut};
-use core::{future::Future, marker::Send, ops::DerefMut, pin::Pin};
-use futures::{sink::Send as SendFut, SinkExt, StreamExt};
+use core::{
+    future::{ready, Future, Ready},
+    marker::Send,
+    pin::Pin,
+};
+use futures::{FutureExt, SinkExt};
 use log::error;
 use quinn::{Connection, Endpoint, RecvStream, SendStream};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
@@ -22,10 +28,11 @@ use std::{
     sync::Arc,
 };
 use tokio::{
+    runtime::Handle,
     select,
     sync::{broadcast, Mutex},
 };
-use tokio_util::codec::{Decoder, Encoder, Framed, FramedParts, FramedRead, FramedWrite};
+use tokio_util::codec::{Encoder, Framed, FramedParts};
 
 /// Returns a [Receiver] bound to the given [IpAddr] which receives messages at the bind path of
 /// the given [Writecap] of the given credentials. The returned type can be used to make
@@ -53,8 +60,7 @@ pub trait MsgCallback: Clone + Send + Sync + Unpin {
     type Arg<'de>: CallMsg<'de>
     where
         Self: 'de;
-    type Return;
-    type CallFut<'de>: Future<Output = Self::Return> + Send
+    type CallFut<'de>: Future<Output = Result<()>> + Send
     where
         Self: 'de;
     fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de>;
@@ -62,7 +68,6 @@ pub trait MsgCallback: Clone + Send + Sync + Unpin {
 
 impl<T: MsgCallback> MsgCallback for &T {
     type Arg<'de> = T::Arg<'de> where Self: 'de;
-    type Return = T::Return;
     type CallFut<'de> = T::CallFut<'de> where Self: 'de;
     fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
         (*self).call(arg)
@@ -71,14 +76,11 @@ impl<T: MsgCallback> MsgCallback for &T {
 
 /// Trait for messages which can be transmitted using the call method.
 pub trait CallMsg<'de>: Serialize + Deserialize<'de> + Send + Sync {
-    type Reply: Serialize + DeserializeOwned + Send;
+    type Reply<'r>: Serialize + Deserialize<'r> + Send;
 }
 
-#[derive(Serialize, Deserialize)]
-pub enum NoReply {}
-
 /// Trait for messages which can be transmitted using the send method.
-/// Types which implement this trait should specify [NoReply] as their reply type.
+/// Types which implement this trait should specify `()` as their reply type.
 pub trait SendMsg<'de>: CallMsg<'de> {}
 
 /// An address which identifies a block on the network. An instance of this struct can be
@@ -207,36 +209,79 @@ pub trait Receiver {
 
 /// A type which can be used to transmit messages.
 pub trait Transmitter {
-    type SendFut<'s, T>: 's + Future<Output = Result<()>> + Send
+    type SendFut<'call, T>: 'call + Future<Output = Result<()>> + Send
     where
-        Self: 's,
-        T: 's + Serialize + Send;
+        Self: 'call,
+        T: 'call + SendMsg<'call>;
 
     /// Transmit a message to the connected [Receiver] without waiting for a reply.
-    fn send<'de, T: 'de + SendMsg<'de>>(&'de mut self, msg: T) -> Self::SendFut<'de, T>;
+    fn send<'call, T: 'call + SendMsg<'call>>(&'call mut self, msg: T) -> Self::SendFut<'call, T>;
 
-    type CallFut<'s, 'de, T>: 's + Future<Output = Result<T::Reply>> + Send
+    type CallFut<'call, T, F>: 'call + Future<Output = Result<F::Return>> + Send
     where
-        Self: 's,
-        T: 's + CallMsg<'de>,
-        T::Reply: 's;
+        Self: 'call,
+        T: 'call + CallMsg<'call>,
+        F: 'static + Send + Sync + DeserCallback;
 
-    /// Transmit a message to the connected [Receiver] and wait for a reply.
-    fn call<'s, 'de, T>(&'s mut self, msg: T) -> Self::CallFut<'s, 'de, T>
+    /// Transmit a message to the connected [Receiver], waits for a reply, then calls the given
+    /// [DeserCallback] with the deserialized reply.
+    fn call<'call, T, F>(&'call mut self, msg: T, callback: F) -> Self::CallFut<'call, T, F>
     where
-        T: 's + CallMsg<'de>,
-        T::Reply: 's;
-
-    type FinishFut: Future<Output = Result<()>> + Send;
-
-    /// Finish any ongoing transmissions and close the connection to the [Receiver].
-    fn finish(self) -> Self::FinishFut;
+        T: 'call + CallMsg<'call>,
+        F: 'static + Send + Sync + DeserCallback;
+
+    /// Transmits a message to the connected [Reciever], waits for a reply, then passes back the
+    /// the reply to the caller.
+    fn call_through<'call, T>(
+        &'call mut self,
+        msg: T,
+    ) -> Self::CallFut<'call, T, Passthrough<T::Reply<'call>>>
+    where
+        T: 'call + CallMsg<'call>,
+        T::Reply<'call>: 'static + Send + Sync + DeserializeOwned,
+    {
+        self.call(msg, Passthrough::new())
+    }
 
     /// Returns the address that this instance is transmitting to.
     fn addr(&self) -> &Arc<BlockAddr>;
 }
 
+pub struct Passthrough<T> {
+    phantom: PhantomData<T>,
+}
+
+impl<T> Passthrough<T> {
+    pub fn new() -> Self {
+        Self {
+            phantom: PhantomData,
+        }
+    }
+}
+
+impl<T> Default for Passthrough<T> {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<T> Clone for Passthrough<T> {
+    fn clone(&self) -> Self {
+        Self::new()
+    }
+}
+
+impl<T: 'static + Send + DeserializeOwned> DeserCallback for Passthrough<T> {
+    type Arg<'de> = T;
+    type Return = T;
+    type CallFut<'de> = Ready<T>;
+    fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
+        ready(arg)
+    }
+}
+
 /// Encodes messages using [btserde].
+#[derive(Debug)]
 struct MsgEncoder;
 
 impl MsgEncoder {
@@ -263,73 +308,26 @@ impl<T: Serialize> Encoder<T> for MsgEncoder {
     }
 }
 
-/// Decodes messages using [btserde].
-struct MsgDecoder<T>(PhantomData<T>);
-
-impl<T> MsgDecoder<T> {
-    fn new() -> Self {
-        Self(PhantomData)
-    }
-}
-
-impl<T: DeserializeOwned> Decoder for MsgDecoder<T> {
-    type Item = T;
-    type Error = btlib::Error;
-
-    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
-        let mut slice: &[u8] = src.as_ref();
-        let payload_len: u64 = match read_from(&mut slice) {
-            Ok(payload_len) => payload_len,
-            Err(err) => {
-                return match err {
-                    btserde::Error::Eof => Ok(None),
-                    btserde::Error::Io(ref io_err) => match io_err.kind() {
-                        std::io::ErrorKind::UnexpectedEof => Ok(None),
-                        _ => Err(err.into()),
-                    },
-                    _ => Err(err.into()),
-                }
-            }
-        };
-        let payload_len: usize = payload_len.try_into().box_err()?;
-        if slice.len() < payload_len {
-            src.reserve(payload_len - slice.len());
-            return Ok(None);
-        }
-        let msg = read_from(&mut slice)?;
-        // Consume all the bytes that have been read out of the buffer.
-        let _ = src.split_to(std::mem::size_of::<u64>() + payload_len);
-        Ok(Some(msg))
-    }
-}
-
-type SharedFrameParts = Arc<Mutex<Option<FramedParts<SendStream, MsgEncoder>>>>;
+type FramedMsg = Framed<SendStream, MsgEncoder>;
+type ArcMutex<T> = Arc<Mutex<T>>;
 
 #[derive(Clone)]
 pub struct Replier {
-    parts: SharedFrameParts,
+    stream: ArcMutex<FramedMsg>,
 }
 
 impl Replier {
-    fn new(send_stream: SendStream) -> Self {
-        let parts = FramedParts::new::<()>(send_stream, MsgEncoder::new());
-        let parts = Arc::new(Mutex::new(Some(parts)));
-        Self { parts }
+    fn new(stream: ArcMutex<FramedMsg>) -> Self {
+        Self { stream }
     }
 
-    pub async fn reply<T: Serialize + Send>(self, reply: T) -> Result<()> {
-        let parts = self.parts;
-        let mut guard = parts.lock().await;
-        // We must ensure the parts are put back before we leave this block.
-        let parts = guard.take().unwrap();
-        let mut stream = Framed::from_parts(parts);
-        let result = stream.send(reply).await;
-        *guard = Some(stream.into_parts());
-        result
+    pub async fn reply<T: Serialize + Send>(&mut self, reply: T) -> Result<()> {
+        let mut guard = self.stream.lock().await;
+        guard.send(reply).await?;
+        Ok(())
     }
 }
 
-#[derive(Clone)]
 struct MsgRecvdCallback<F> {
     path: Arc<BlockPath>,
     replier: Replier,
@@ -337,10 +335,10 @@ struct MsgRecvdCallback<F> {
 }
 
 impl<F: MsgCallback> MsgRecvdCallback<F> {
-    fn new(path: Arc<BlockPath>, replier: Replier, inner: F) -> Self {
+    fn new(path: Arc<BlockPath>, framed_msg: ArcMutex<FramedMsg>, inner: F) -> Self {
         Self {
             path,
-            replier,
+            replier: Replier::new(framed_msg),
             inner,
         }
     }
@@ -348,9 +346,9 @@ impl<F: MsgCallback> MsgRecvdCallback<F> {
 
 impl<F: MsgCallback> DeserCallback for MsgRecvdCallback<F> {
     type Arg<'de> = Envelope<F::Arg<'de>> where Self: 'de;
-    type Return = F::Return;
-    type CallFut<'s> = F::CallFut<'s> where F: 's;
-    fn call<'de>(&'de self, arg: Envelope<F::Arg<'de>>) -> Self::CallFut<'de> {
+    type Return = Result<()>;
+    type CallFut<'de> = F::CallFut<'de> where F: 'de, Self: 'de;
+    fn call<'de>(&'de mut self, arg: Envelope<F::Arg<'de>>) -> Self::CallFut<'de> {
         let replier = match arg.kind {
             MsgKind::Call => Some(self.replier.clone()),
             MsgKind::Send => None,
@@ -440,11 +438,13 @@ impl QuicReceiver {
             let connection = unwrap_or_continue!(connecting.await, |err| error!(
                 "error accepting QUIC connection: {err}"
             ));
-            tokio::spawn(Self::handle_connection(
-                connection,
-                callback.clone(),
-                stop_rx.resubscribe(),
-            ));
+            let callback = callback.clone();
+            let stop_rx = stop_rx.resubscribe();
+            // spawn_blocking is used to allow the user supplied callback to to block without
+            // disrupting the main thread pool.
+            tokio::task::spawn_blocking(move || {
+                Handle::current().block_on(Self::handle_connection(connection, callback, stop_rx))
+            });
         }
     }
 
@@ -457,17 +457,39 @@ impl QuicReceiver {
             Self::client_path(connection.peer_identity()),
             |err| error!("failed to get client path from peer identity: {err}")
         );
-        let (send_stream, recv_stream) = unwrap_or_return!(
-            connection.accept_bi().await,
-            |err| error!("error accepting receive stream: {err}")
-        );
-        let replier = Replier::new(send_stream);
-        let callback = MsgRecvdCallback::new(client_path, replier, callback);
-        let mut msg_stream = CallbackFramed::new(recv_stream);
+        let mut frame_parts_opt: Option<FramedParts<SendStream, MsgEncoder>> = None;
         loop {
-            let decode_result = await_or_stop!(msg_stream.next(callback.clone()), stop_rx.recv());
-            if let Err(ref err) = decode_result {
-                error!("msg_stream produced an error: {err}");
+            let result = await_or_stop!(connection.accept_bi().map(Some), stop_rx.recv());
+            let (send_stream, recv_stream) =
+                unwrap_or_continue!(result, |err| error!("error accepting stream: {err}"));
+            let frame_parts = match frame_parts_opt {
+                Some(mut frame_parts) => {
+                    frame_parts.io = send_stream;
+                    frame_parts
+                }
+                None => FramedParts::new::<<<F as MsgCallback>::Arg<'_> as CallMsg>::Reply<'_>>(
+                    send_stream,
+                    MsgEncoder::new(),
+                ),
+            };
+            let framed_msg = Arc::new(Mutex::new(FramedMsg::from_parts(frame_parts)));
+            let callback =
+                MsgRecvdCallback::new(client_path.clone(), framed_msg.clone(), callback.clone());
+            let mut msg_stream = CallbackFramed::new(recv_stream);
+            let result = msg_stream
+                .next(callback)
+                .await
+                .ok_or_else(|| bterr!("client closed stream before sending a message"));
+            let msg_framed = Arc::try_unwrap(framed_msg).unwrap();
+            let msg_framed = msg_framed.into_inner();
+            frame_parts_opt = Some(msg_framed.into_parts());
+            match unwrap_or_continue!(result) {
+                Err(err) => error!("msg_stream produced an error: {err}"),
+                Ok(result) => {
+                    if let Err(err) = result {
+                        error!("callback returned an error: {err}");
+                    }
+                }
             }
         }
     }
@@ -510,10 +532,23 @@ impl Receiver for QuicReceiver {
     }
 }
 
+macro_rules! cleanup_on_err {
+    ($result:expr, $guard:ident, $parts:ident) => {
+        match $result {
+            Ok(value) => value,
+            Err(err) => {
+                *$guard = Some($parts);
+                return Err(err.into());
+            }
+        }
+    };
+}
+
 struct QuicTransmitter {
     addr: Arc<BlockAddr>,
-    sink: FramedWrite<SendStream, MsgEncoder>,
-    recv_stream: Mutex<RecvStream>,
+    connection: Connection,
+    send_parts: Mutex<Option<FramedParts<SendStream, MsgEncoder>>>,
+    recv_buf: Mutex<Option<BytesMut>>,
 }
 
 impl QuicTransmitter {
@@ -530,59 +565,78 @@ impl QuicTransmitter {
             "UNIMPORTANT",
         )?;
         let connection = connecting.await?;
-        let (send_stream, recv_stream) = connection.open_bi().await?;
-        let sink = FramedWrite::new(send_stream, MsgEncoder::new());
+        let send_parts = Mutex::new(None);
+        let recv_buf = Mutex::new(Some(BytesMut::new()));
         Ok(Self {
             addr,
-            sink,
-            recv_stream: Mutex::new(recv_stream),
+            connection,
+            send_parts,
+            recv_buf,
         })
     }
+
+    async fn transmit<T: Serialize>(&mut self, envelope: Envelope<T>) -> Result<RecvStream> {
+        let mut guard = self.send_parts.lock().await;
+        let (send_stream, recv_stream) = self.connection.open_bi().await?;
+        let parts = match guard.take() {
+            Some(mut parts) => {
+                parts.io = send_stream;
+                parts
+            }
+            None => FramedParts::new::<Envelope<T>>(send_stream, MsgEncoder::new()),
+        };
+        let mut sink = Framed::from_parts(parts);
+        let result = sink.send(envelope).await;
+        let parts = sink.into_parts();
+        cleanup_on_err!(result, guard, parts);
+        *guard = Some(parts);
+        Ok(recv_stream)
+    }
+
+    async fn call<'ser, T, F>(&'ser mut self, msg: T, callback: F) -> Result<F::Return>
+    where
+        T: 'ser + CallMsg<'ser>,
+        F: 'static + Send + Sync + DeserCallback,
+    {
+        let recv_stream = self.transmit(Envelope::call(msg)).await?;
+        let mut guard = self.recv_buf.lock().await;
+        let buffer = guard.take().unwrap();
+        let mut callback_framed = CallbackFramed::from_parts(recv_stream, buffer);
+        let result = callback_framed
+            .next(callback)
+            .await
+            .ok_or_else(|| bterr!("server hung up before sending reply"));
+        let (_, buffer) = callback_framed.into_parts();
+        let output = cleanup_on_err!(result, guard, buffer);
+        *guard = Some(buffer);
+        output
+    }
 }
 
-/// TODO: Once the "Permit impl Trait in type aliases"
-/// https://github.com/rust-lang/rust/issues/63063
-/// feature lands the future types in this implementation should be rewritten to
-/// use it.
 impl Transmitter for QuicTransmitter {
     fn addr(&self) -> &Arc<BlockAddr> {
         &self.addr
     }
 
-    type SendFut<'s, T> = SendFut<'s, FramedWrite<SendStream, MsgEncoder>, Envelope<T>>
-        where T: 's + Serialize + Send;
-
-    fn send<'de, T: 'de + SendMsg<'de>>(&'de mut self, msg: T) -> Self::SendFut<'de, T> {
-        self.sink.send(Envelope::send(msg))
-    }
-
-    type FinishFut = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
+    type SendFut<'ser, T> = impl 'ser + Future<Output = Result<()>> + Send
+        where T: 'ser + SendMsg<'ser>;
 
-    fn finish(mut self) -> Self::FinishFut {
-        Box::pin(async move {
-            let steam: &mut SendStream = self.sink.get_mut();
-            steam.finish().await.map_err(|err| bterr!(err))
-        })
+    fn send<'ser, 'r, T: 'ser + SendMsg<'ser>>(&'ser mut self, msg: T) -> Self::SendFut<'ser, T> {
+        self.transmit(Envelope::send(msg))
+            .map(|result| result.map(|_| ()))
     }
 
-    type CallFut<'s, 'de, T> = Pin<Box<dyn 's + Future<Output = Result<T::Reply>> + Send>>
+    type CallFut<'ser, T, F> = impl 'ser + Future<Output = Result<F::Return>> + Send
     where
-        T: 's + CallMsg<'de>,
-        T::Reply: 's;
+        Self: 'ser,
+        T: 'ser + CallMsg<'ser>,
+        F: 'static + Send + Sync + DeserCallback;
 
-    fn call<'s, 'de, T>(&'s mut self, msg: T) -> Self::CallFut<'s, 'de, T>
+    fn call<'ser, T, F>(&'ser mut self, msg: T, callback: F) -> Self::CallFut<'ser, T, F>
     where
-        T: 's + CallMsg<'de>,
-        T::Reply: 's,
+        T: 'ser + CallMsg<'ser>,
+        F: 'static + Send + Sync + DeserCallback,
     {
-        Box::pin(async move {
-            self.sink.send(Envelope::call(msg)).await?;
-            let mut guard = self.recv_stream.lock().await;
-            let mut source = FramedRead::new(guard.deref_mut(), MsgDecoder::<T::Reply>::new());
-            source
-                .next()
-                .await
-                .ok_or_else(|| bterr!("server hung up before sending reply"))?
-        })
+        self.call(msg, callback)
     }
 }

+ 134 - 13
crates/btmsg/tests/tests.rs

@@ -1,16 +1,19 @@
+#![feature(type_alias_impl_trait)]
+
 use btmsg::*;
 
 use btlib::{
     crypto::{ConcreteCreds, Creds, CredsPriv},
     BlockPath, Epoch, Principal, Principaled,
 };
-use core::future::Future;
+use core::future::{ready, Future, Ready};
 use ctor::ctor;
 use lazy_static::lazy_static;
 use serde::{Deserialize, Serialize};
 use std::{
+    io::Write,
     net::{IpAddr, Ipv6Addr},
-    sync::{Arc, Mutex},
+    sync::{Arc, Mutex as SyncMutex},
     time::Duration,
 };
 use tokio::sync::mpsc::{self, Sender};
@@ -19,7 +22,20 @@ use tokio::sync::mpsc::{self, Sender};
 fn setup_logging() {
     use env_logger::Env;
     let env = Env::default().default_filter_or("ERROR");
-    env_logger::init_from_env(env);
+    env_logger::builder()
+        .format(|fmt, record| {
+            writeln!(
+                fmt,
+                "[{} {} {}:{}] {}",
+                chrono::Utc::now().to_rfc3339(),
+                record.level(),
+                record.file().unwrap_or("(unknown)"),
+                record.line().unwrap_or(u32::MAX),
+                record.args(),
+            )
+        })
+        .parse_env(env)
+        .init();
 }
 
 lazy_static! {
@@ -57,7 +73,7 @@ enum Msg<'a> {
 }
 
 impl<'a> CallMsg<'a> for Msg<'a> {
-    type Reply = Reply;
+    type Reply<'b> = Reply;
 }
 
 impl<'a> SendMsg<'a> for Msg<'a> {}
@@ -98,9 +114,10 @@ impl<S: 'static + Send, Fut: Send + Future> Delegate<S, Fut> {
     }
 }
 
-impl<S: 'static + Send, Fut: Send + Future> MsgCallback for Delegate<S, Fut> {
+impl<S: 'static + Send, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
+    for Delegate<S, Fut>
+{
     type Arg<'de> = Msg<'de> where Self: 'de;
-    type Return = Fut::Output;
     type CallFut<'s> = Fut where Fut: 's;
     fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
         (self.func)(arg, self.sender.clone())
@@ -136,7 +153,7 @@ async fn proc_tx_rx<F: 'static + MsgCallback>(func: F) -> (impl Transmitter, imp
 
 async fn file_server() -> (impl Transmitter, impl Receiver) {
     let (sender, _) = mpsc::channel::<()>(1);
-    let file = Arc::new(Mutex::new([0, 1, 2, 3, 4, 5, 6, 7]));
+    let file = Arc::new(SyncMutex::new([0, 1, 2, 3, 4, 5, 6, 7]));
     proc_tx_rx(Delegate::new(
         sender,
         move |mut received: MsgReceived<Msg<'_>>, _| {
@@ -161,7 +178,7 @@ async fn file_server() -> (impl Transmitter, impl Receiver) {
                 }
                 _ => Reply::Fail,
             };
-            let replier = received.take_replier().unwrap();
+            let mut replier = received.take_replier().unwrap();
             async move { replier.reply(reply_body).await }
         },
     ))
@@ -194,6 +211,7 @@ async fn message_received_is_message_sent() {
             let sender = sender.clone();
             async move {
                 sender.send(passed).await.unwrap();
+                Ok(())
             }
         },
     ))
@@ -214,6 +232,7 @@ async fn message_received_from_path_is_correct() {
             let sender = sender.clone();
             async move {
                 sender.send(path).await.unwrap();
+                Ok(())
             }
         },
     ))
@@ -228,7 +247,7 @@ async fn message_received_from_path_is_correct() {
 async fn reply_to_read() {
     let (mut sender, _receiver) = file_server().await;
     let reply = sender
-        .call(Msg::Read { offset: 2, size: 2 })
+        .call_through::<Msg>(Msg::Read { offset: 2, size: 2 })
         .await
         .unwrap();
     if let Reply::ReadReply { offset, buf } = reply {
@@ -244,7 +263,7 @@ async fn call_twice() {
     let (mut sender, _receiver) = file_server().await;
 
     let reply = sender
-        .call(Msg::Write {
+        .call_through::<Msg>(Msg::Write {
             offset: 1,
             buf: &[1, 1],
         })
@@ -256,7 +275,7 @@ async fn call_twice() {
         panic!("reply was not the right type");
     };
     let reply = sender
-        .call(Msg::Read { offset: 1, size: 2 })
+        .call_through::<Msg>(Msg::Read { offset: 1, size: 2 })
         .await
         .unwrap();
     if let Reply::ReadReply { offset, buf } = reply {
@@ -269,14 +288,14 @@ async fn call_twice() {
 
 #[tokio::test]
 async fn separate_transmitter() {
-    let (_senderx, receiver) = file_server().await;
+    let (_sender, receiver) = file_server().await;
     let creds = proc_creds();
     let mut transmitter = transmitter(receiver.addr().clone(), Arc::new(creds))
         .await
         .unwrap();
 
     let reply = transmitter
-        .call(Msg::Write {
+        .call_through::<Msg>(Msg::Write {
             offset: 5,
             buf: &[7, 7, 7],
         })
@@ -289,3 +308,105 @@ async fn separate_transmitter() {
     };
     assert!(matched);
 }
+
+#[derive(Serialize, Deserialize)]
+struct Read {
+    offset: usize,
+    size: usize,
+}
+
+#[derive(Serialize, Deserialize)]
+struct ReadReply<'a> {
+    buf: &'a [u8],
+}
+
+impl<'a> CallMsg<'a> for Read {
+    type Reply<'b> = ReadReply<'b>;
+}
+
+#[derive(Clone)]
+struct ReadChecker<'a> {
+    expected: &'a [u8],
+}
+
+impl<'a> ReadChecker<'a> {
+    fn new(expected: &'a [u8]) -> Self {
+        Self { expected }
+    }
+}
+
+impl<'a> DeserCallback for ReadChecker<'a> {
+    type Arg<'de> = ReadReply<'de> where Self: 'de;
+    type Return = bool;
+    type CallFut<'s> = Ready<bool> where Self: 's;
+    fn call<'de>(&'de mut self, arg: Self::Arg<'de>) -> Self::CallFut<'de> {
+        ready(self.expected == arg.buf)
+    }
+}
+
+trait ActionFn<Arg, Fut: Send + Future>: Send + Sync + Fn(MsgReceived<Arg>) -> Fut {}
+
+impl<Arg, Fut: Send + Future, T: Send + Sync + Fn(MsgReceived<Arg>) -> Fut> ActionFn<Arg, Fut>
+    for T
+{
+}
+
+struct Action<Arg, Fut> {
+    func: Arc<dyn ActionFn<Arg, Fut>>,
+}
+
+impl<Arg, Fut: Send + Future> Action<Arg, Fut> {
+    fn new<F: 'static + ActionFn<Arg, Fut>>(func: F) -> Self {
+        Self {
+            func: Arc::new(func),
+        }
+    }
+}
+
+impl<Arg, Fut> Clone for Action<Arg, Fut> {
+    fn clone(&self) -> Self {
+        Self {
+            func: self.func.clone(),
+        }
+    }
+}
+
+impl<Arg: for<'a> CallMsg<'a>, Fut: Send + Future<Output = btlib::Result<()>>> MsgCallback
+    for Action<Arg, Fut>
+{
+    type Arg<'de> = Arg where Arg: 'de, Fut: 'de;
+    type CallFut<'de> = Fut where Arg: 'de, Fut: 'de;
+    fn call<'de>(&'de self, arg: MsgReceived<Self::Arg<'de>>) -> Self::CallFut<'de> {
+        (self.func)(arg)
+    }
+}
+
+async fn read_server() -> (impl Transmitter, impl Receiver) {
+    let file = [0, 1, 2, 3, 4, 5, 6, 7];
+    proc_tx_rx(Action::new(move |mut msg: MsgReceived<Read>| async move {
+        let body = msg.body();
+        let start = body.offset;
+        let end = start + body.size;
+        let buf = &file[start..end];
+        let mut replier = msg.take_replier().unwrap();
+        replier.reply(ReadReply { buf }).await
+    }))
+    .await
+}
+
+#[tokio::test]
+async fn call_with_lifetime() {
+    let (mut sender, _receiver) = read_server().await;
+
+    let correct_one = sender
+        .call(Read { offset: 2, size: 3 }, ReadChecker::new(&[2, 3, 4]))
+        .await
+        .unwrap();
+    let correct_two = sender
+        .call(Read { offset: 0, size: 2 }, ReadChecker::new(&[0, 1]))
+        .await
+        .unwrap();
+
+    assert!(correct_one);
+    assert!(correct_two);
+}