Browse Source

Merge branch 'LocalFsAsyncConversion'

Matthew Carr 2 years ago
parent
commit
6e94b66443

+ 0 - 21
Cargo.lock

@@ -189,7 +189,6 @@ dependencies = [
  "paste",
  "positioned-io",
  "serde",
- "swimmer",
  "tokio",
 ]
 
@@ -200,7 +199,6 @@ dependencies = [
  "btfproto",
  "btlib",
  "btserde",
- "bytes",
  "lazy_static",
  "libc",
  "tempdir",
@@ -234,7 +232,6 @@ dependencies = [
  "btfproto-tests",
  "btlib",
  "btserde",
- "bytes",
  "ctor",
  "env_logger",
  "fuse-backend-rs",
@@ -1876,15 +1873,6 @@ dependencies = [
  "syn",
 ]
 
-[[package]]
-name = "swimmer"
-version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "963def60929892c4b13817d852a02ae7516d43e5aa8e0eeb560f580b1ce1e157"
-dependencies = [
- "thread_local",
-]
-
 [[package]]
 name = "swtpm-harness"
 version = "0.1.0"
@@ -1980,15 +1968,6 @@ dependencies = [
  "syn",
 ]
 
-[[package]]
-name = "thread_local"
-version = "0.3.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c6b53e329000edc2b34dbe8545fd20e55a333362d0a321909685a19bd28c3f1b"
-dependencies = [
- "lazy_static",
-]
-
 [[package]]
 name = "time"
 version = "0.1.44"

+ 0 - 1
crates/btfproto-tests/Cargo.toml

@@ -11,6 +11,5 @@ btfproto = { path = "../btfproto" }
 btserde = { path = "../btserde" }
 tempdir = { version = "0.3.7" }
 lazy_static = { version = "1.4.0" }
-bytes = { version = "1.3.0" }
 libc = { version = "0.2.137" }
 tokio = { version = "1.24.2" }

+ 161 - 102
crates/btfproto-tests/src/local_fs_tests.rs

@@ -80,7 +80,9 @@ impl LocalFsTest {
             .writecap()
             .ok_or(BlockError::MissingWritecap)
             .unwrap();
-        let fs = LocalFs::new_empty(path, 0, root_creds, ModeAuthorizer).unwrap();
+        let fs = LocalFs::new_empty(path, 0, root_creds, ModeAuthorizer)
+            .await
+            .unwrap();
 
         let proc_rec = IssuedProcRec {
             addr: IpAddr::V6(Ipv6Addr::LOCALHOST),
@@ -122,9 +124,9 @@ mod tests {
 
     use btfproto::{local_fs::Error, msg::*};
     use btlib::{Inode, Result, SECTOR_SZ_DEFAULT};
-    use bytes::BytesMut;
     use std::{
         io::{self, Cursor, Write as IoWrite},
+        ops::Deref,
         sync::Arc,
     };
 
@@ -148,30 +150,24 @@ mod tests {
 
         const LEN: usize = 32;
         let expected = [1u8; LEN];
-        let WriteReply { written, .. } = bt
-            .write(
-                from,
-                inode,
-                handle,
-                0,
-                expected.len() as u64,
-                expected.as_slice(),
-            )
-            .await
-            .unwrap();
+        let write_msg = Write {
+            inode,
+            handle,
+            offset: 0,
+            data: expected.as_slice(),
+        };
+        let WriteReply { written, .. } = bt.write(from, write_msg).await.unwrap();
         assert_eq!(LEN as u64, written);
 
-        let mut actual = [0u8; LEN];
         let read_msg = Read {
             inode,
             handle,
             offset: 0,
             size: LEN as u64,
         };
-        bt.read(from, read_msg, |data| actual.copy_from_slice(data))
-            .unwrap();
+        let guard = bt.read(from, read_msg).await.unwrap();
 
-        assert_eq!(expected, actual)
+        assert_eq!(expected, guard.deref())
     }
 
     #[tokio::test]
@@ -220,10 +216,13 @@ mod tests {
             };
             let CreateReply { handle, inode, .. } = bt.create(from, create_msg).await.unwrap();
 
-            let WriteReply { written, .. } = bt
-                .write(from, inode, handle, 0, EXPECTED.len() as u64, EXPECTED)
-                .await
-                .unwrap();
+            let write_msg = Write {
+                inode,
+                handle,
+                offset: 0,
+                data: EXPECTED,
+            };
+            let WriteReply { written, .. } = bt.write(from, write_msg).await.unwrap();
             assert_eq!(EXPECTED.len() as u64, written);
 
             let close_msg = Close { inode, handle };
@@ -246,17 +245,15 @@ mod tests {
         };
         let OpenReply { handle, .. } = bt.open(from, open_msg).await.unwrap();
 
-        let mut actual = BytesMut::new();
         let read_msg = Read {
             inode,
             handle,
             offset: 0,
             size: EXPECTED.len() as u64,
         };
-        bt.read(from, read_msg, |data| actual.extend_from_slice(data))
-            .unwrap();
+        let guard = bt.read(from, read_msg).await.unwrap();
 
-        assert_eq!(EXPECTED, &actual)
+        assert_eq!(EXPECTED, guard.deref())
     }
 
     /// Tests that an error is returned by the `Blocktree::write` method if the file was opened
@@ -287,9 +284,13 @@ mod tests {
         let OpenReply { handle, .. } = bt.open(from, open_msg).await.unwrap();
 
         let data = [1u8; 32];
-        let result = bt
-            .write(from, inode, handle, 0, data.len() as u64, data.as_slice())
-            .await;
+        let write_msg = Write {
+            inode,
+            handle,
+            offset: 0,
+            data: data.as_slice(),
+        };
+        let result = bt.write(from, write_msg).await;
 
         let err = result.err().unwrap();
         let err = err.downcast::<Error>().unwrap();
@@ -380,23 +381,24 @@ mod tests {
             umask: 0,
         };
         let CreateReply { inode, handle, .. } = bt.create(from, create_msg).await.unwrap();
-        let WriteReply { written, .. } = bt
-            .write(from, inode, handle, 0, DATA.len() as u64, DATA.as_slice())
-            .await
-            .unwrap();
+        let write_msg = Write {
+            inode,
+            handle,
+            offset: 0,
+            data: DATA.as_slice(),
+        };
+        let WriteReply { written, .. } = bt.write(from, write_msg).await.unwrap();
         assert_eq!(DATA.len() as u64, written);
         const SIZE: usize = DATA.len() / 2;
-        let mut actual = Vec::with_capacity(SIZE);
         let read_msg = Read {
             inode,
             handle,
             offset: 0,
             size: SIZE as u64,
         };
-        bt.read(from, read_msg, |data| actual.extend_from_slice(data))
-            .unwrap();
+        let guard = bt.read(from, read_msg).await.unwrap();
 
-        assert_eq!(&[0, 1, 2, 3], actual.as_slice());
+        assert_eq!(&[0, 1, 2, 3], guard.deref());
     }
 
     /// Returns an integer array starting at the given value and increasing by one for each
@@ -434,36 +436,35 @@ mod tests {
             umask: 0,
         };
         let CreateReply { inode, handle, .. } = bt.create(from, create_msg).await.unwrap();
-        let WriteReply { written, .. } = bt
-            .write(from, inode, handle, 0, DATA.len() as u64, DATA.as_slice())
-            .await
-            .unwrap();
+        let write_msg = Write {
+            inode,
+            handle,
+            offset: 0,
+            data: DATA.as_slice(),
+        };
+        let WriteReply { written, .. } = bt.write(from, write_msg).await.unwrap();
         assert_eq!(DATA.len() as u64, written);
-        let case = Box::new(case);
-        let cb = Arc::new(Box::new(move |offset: usize| {
-            // Notice that we have concurrent reads to different offsets using the same handle.
-            // Without proper synchronization, this shouldn't work.
-            let mut actual = Vec::with_capacity(SIZE);
-            let read_msg = Read {
-                inode,
-                handle,
-                offset: offset as u64,
-                size: SIZE as u64,
-            };
-            case.fs
-                .read(case.from(), read_msg, |data| actual.extend_from_slice(data))
-                .unwrap();
-            let expected = integer_array::<SIZE>(offset as u8);
-            assert_eq!(&expected, actual.as_slice());
-        }));
+        let case = Arc::new(case);
 
         let mut handles = Vec::with_capacity(NREADS);
         for offset in (0..NREADS).map(|e| e * SIZE) {
-            let thread_cb = cb.clone();
-            handles.push(std::thread::spawn(move || thread_cb(offset)));
+            let case = case.clone();
+            handles.push(tokio::spawn(async move {
+                // Notice that we have concurrent reads to different offsets using the same handle.
+                // Without proper synchronization, this shouldn't work.
+                let read_msg = Read {
+                    inode,
+                    handle,
+                    offset: offset as u64,
+                    size: SIZE as u64,
+                };
+                let guard = case.fs.read(case.from(), read_msg).await.unwrap();
+                let expected = integer_array::<SIZE>(offset as u8);
+                assert_eq!(&expected, guard.deref());
+            }));
         }
         for handle in handles {
-            handle.join().unwrap();
+            handle.await.unwrap();
         }
     }
 
@@ -665,9 +666,13 @@ mod tests {
             umask: 0,
         };
         let CreateReply { inode, handle, .. } = bt.create(owner, create_msg).await.unwrap();
-        let result = bt
-            .write(&other, inode, handle, 0, 3, [1, 2, 3].as_slice())
-            .await;
+        let write_msg = Write {
+            inode,
+            handle,
+            offset: 0,
+            data: [1, 2, 3].as_slice(),
+        };
+        let result = bt.write(&other, write_msg).await;
 
         let err = result.err().unwrap().downcast::<Error>().unwrap();
         let matched = if let Error::WrongOwner = err {
@@ -693,30 +698,29 @@ mod tests {
             umask: 0,
         };
         let CreateReply { inode, handle, .. } = bt.create(from, create_msg).await.unwrap();
-        let mut actual = [0u8; 8];
+        const LEN: u64 = 8;
         let alloc_msg = Allocate {
             inode,
             handle,
             offset: None,
-            size: actual.len() as u64,
+            size: LEN,
         };
         bt.allocate(from, alloc_msg).await.unwrap();
-        let read_msg = Read {
-            inode,
-            handle,
-            offset: 0,
-            size: actual.len() as u64,
-        };
-        bt.read(from, read_msg, |data| actual.copy_from_slice(data))
-            .unwrap();
         let read_meta_msg = ReadMeta {
             inode,
             handle: Some(handle),
         };
         let ReadMetaReply { attrs, .. } = bt.read_meta(from, read_meta_msg).await.unwrap();
+        let read_msg = Read {
+            inode,
+            handle,
+            offset: 0,
+            size: LEN,
+        };
+        let guard = bt.read(from, read_msg).await.unwrap();
 
-        assert_eq!([0u8; 8], actual);
-        assert_eq!(actual.len() as u64, attrs.size);
+        assert_eq!([0u8; 8], guard.deref());
+        assert_eq!(guard.len() as u64, attrs.size);
     }
 
     #[tokio::test]
@@ -751,12 +755,10 @@ mod tests {
                 offset: 0,
                 size,
             };
-            let cb = |data: &[u8]| {
-                actual.write(data)?;
-                size -= data.len() as u64;
-                Ok::<_, io::Error>(())
-            };
-            bt.read(from, read_msg, cb).unwrap().unwrap();
+            let guard = bt.read(from, read_msg).await.unwrap();
+            let data = guard.deref();
+            actual.write(data).unwrap();
+            size -= data.len() as u64;
         }
         let read_meta_msg = ReadMeta {
             inode,
@@ -792,22 +794,20 @@ mod tests {
             size,
         };
         bt.allocate(from, alloc_msg).await.unwrap();
-        let mut actual = vec![0u8; LEN];
+        let read_meta_msg = ReadMeta {
+            inode,
+            handle: Some(handle),
+        };
+        let ReadMetaReply { attrs, .. } = bt.read_meta(from, read_meta_msg).await.unwrap();
         let read_msg = Read {
             inode,
             handle,
             offset: 0,
             size,
         };
-        bt.read(from, read_msg, |data| actual.copy_from_slice(data))
-            .unwrap();
-        let read_meta_msg = ReadMeta {
-            inode,
-            handle: Some(handle),
-        };
-        let ReadMetaReply { attrs, .. } = bt.read_meta(from, read_meta_msg).await.unwrap();
+        let guard = bt.read(from, read_msg).await.unwrap();
 
-        assert_eq!(vec![0u8; LEN], actual);
+        assert_eq!(vec![0u8; LEN], guard.deref());
         assert_eq!(LEN as u64, attrs.size);
     }
 
@@ -829,10 +829,13 @@ mod tests {
         };
         let CreateReply { inode, handle, .. } = bt.create(from, create_msg).await.unwrap();
         const LEN: usize = 8;
-        let WriteReply { written, .. } = bt
-            .write(from, inode, handle, 0, LEN as u64, [1u8; LEN].as_slice())
-            .await
-            .unwrap();
+        let write_msg = Write {
+            inode,
+            handle,
+            offset: 0,
+            data: [1u8; LEN].as_slice(),
+        };
+        let WriteReply { written, .. } = bt.write(from, write_msg).await.unwrap();
         assert_eq!(LEN as u64, written);
         let alloc_msg = Allocate {
             inode,
@@ -841,22 +844,78 @@ mod tests {
             size: (LEN / 2) as u64,
         };
         bt.allocate(from, alloc_msg).await.unwrap();
-        let mut actual = [0u8; LEN];
+        let read_meta_msg = ReadMeta {
+            inode,
+            handle: Some(handle),
+        };
+        let ReadMetaReply { attrs, .. } = bt.read_meta(from, read_meta_msg).await.unwrap();
         let read_msg = Read {
             inode,
             handle,
             offset: 0,
             size: LEN as u64,
         };
-        bt.read(from, read_msg, |data| actual.copy_from_slice(data))
-            .unwrap();
-        let read_meta_msg = ReadMeta {
+        let actual = bt.read(from, read_msg).await.unwrap();
+
+        assert_eq!([1u8; LEN], actual.deref());
+        assert_eq!(LEN as u64, attrs.size);
+    }
+
+    #[tokio::test]
+    async fn read_at_non_current_position() {
+        const FILENAME: &str = "MANIFESTO.rtf";
+        let case = LocalFsTest::new_empty().await;
+        let bt = &case.fs;
+        let from = case.from();
+
+        let msg = Create {
+            parent: SpecInodes::RootDir.into(),
+            name: FILENAME,
+            flags: FlagValue::ReadWrite.into(),
+            mode: 0o644,
+            umask: 0,
+        };
+        let CreateReply {
             inode,
-            handle: Some(handle),
+            handle,
+            entry,
+            ..
+        } = bt.create(from, msg).await.unwrap();
+        let sect_sz64 = entry.attr.sect_sz;
+        let sect_sz: usize = sect_sz64.try_into().unwrap();
+        let mut data = vec![1u8; sect_sz];
+        let msg = Write {
+            inode,
+            handle,
+            offset: 0,
+            data: data.as_slice(),
         };
-        let ReadMetaReply { attrs, .. } = bt.read_meta(from, read_meta_msg).await.unwrap();
+        let WriteReply { written, .. } = bt.write(from, msg).await.unwrap();
+        assert_eq!(sect_sz64, written);
+        data.truncate(0);
+        data.extend(std::iter::repeat(2).take(sect_sz));
+        let msg = Write {
+            inode,
+            handle,
+            offset: sect_sz64,
+            data: data.as_slice(),
+        };
+        let WriteReply { written, .. } = bt.write(from, msg).await.unwrap();
+        assert_eq!(sect_sz64, written);
+        // The Accessor for  this block should now have the second sector loaded, so it will have to
+        // seek back to the first in order to respond to this read request.
+        let msg = Read {
+            inode,
+            handle,
+            offset: 0,
+            size: sect_sz64,
+        };
+        let guard = bt.read(from, msg).await.unwrap();
 
-        assert_eq!([1u8; LEN], actual);
-        assert_eq!(LEN as u64, attrs.size);
+        assert!(guard
+            .deref()
+            .iter()
+            .map(|e| *e)
+            .eq(std::iter::repeat(1u8).take(sect_sz)));
     }
 }

+ 0 - 1
crates/btfproto/Cargo.toml

@@ -24,4 +24,3 @@ positioned-io = { version = "0.3.1", optional = true }
 fuse-backend-rs = { version = "0.9.6", optional = true }
 btserde = { path = "../btserde", optional = true }
 bytes = { version = "1.3.0", optional = true }
-swimmer = "0.3.0"

File diff suppressed because it is too large
+ 429 - 333
crates/btfproto/src/local_fs.rs


+ 3 - 3
crates/btfproto/src/msg.rs

@@ -25,7 +25,7 @@ pub enum FsMsg<'a> {
     Open(Open),
     Read(Read),
     #[serde(borrow)]
-    Write(Write<'a>),
+    Write(Write<&'a [u8]>),
     Flush(Flush),
     ReadDir(ReadDir),
     #[serde(borrow)]
@@ -459,11 +459,11 @@ pub struct ReadReply<'a> {
 }
 
 #[derive(Serialize, Deserialize)]
-pub struct Write<'a> {
+pub struct Write<R> {
     pub inode: Inode,
     pub handle: Handle,
     pub offset: u64,
-    pub data: &'a [u8],
+    pub data: R,
 }
 
 #[derive(Serialize, Deserialize)]

+ 29 - 73
crates/btfproto/src/server.rs

@@ -1,14 +1,10 @@
 // SPDX-License-Identifier: AGPL-3.0-or-later
-use crate::{
-    msg::{Read as ReadMsg, *},
-    Handle, Inode,
-};
+use crate::msg::{Read as ReadMsg, *};
 
 use btlib::{crypto::Creds, BlockPath, Result};
 use btmsg::{receiver, MsgCallback, MsgReceived, Receiver};
 use core::future::Future;
-use std::{io::Read, net::IpAddr, sync::Arc};
-use swimmer::Pool;
+use std::{net::IpAddr, ops::Deref, sync::Arc};
 
 pub trait FsProvider: Send + Sync {
     type LookupFut<'c>: Send + Future<Output = Result<LookupReply>>
@@ -26,24 +22,20 @@ pub trait FsProvider: Send + Sync {
         Self: 'c;
     fn open<'c>(&'c self, from: &'c Arc<BlockPath>, msg: Open) -> Self::OpenFut<'c>;
 
-    fn read<'c, R, F>(&'c self, from: &'c Arc<BlockPath>, msg: ReadMsg, callback: F) -> Result<R>
-    where
-        F: 'c + FnOnce(&[u8]) -> R;
-
-    type WriteFut<'c>: Send + Future<Output = Result<WriteReply>>
+    type ReadGuard: Send + Sync + Deref<Target = [u8]>;
+    type ReadFut<'c>: Send + Future<Output = Result<Self::ReadGuard>>
     where
         Self: 'c;
-    fn write<'c, R>(
-        &'c self,
-        from: &'c Arc<BlockPath>,
-        inode: Inode,
-        handle: Handle,
-        offset: u64,
-        size: u64,
-        reader: R,
-    ) -> Self::WriteFut<'c>
+    /// Reads from the file specified in the given message.
+    /// ### WARNING
+    /// The returned guard must be dropped before another method is called on this provider.
+    /// Otherwise deadlock _will_ occur.
+    fn read<'c>(&'c self, from: &'c Arc<BlockPath>, msg: ReadMsg) -> Self::ReadFut<'c>;
+
+    type WriteFut<'r>: Send + Future<Output = Result<WriteReply>>
     where
-        R: 'c + Read;
+        Self: 'r;
+    fn write<'c>(&'c self, from: &'c Arc<BlockPath>, write: Write<&'c [u8]>) -> Self::WriteFut<'c>;
 
     type FlushFut<'c>: Send + Future<Output = Result<()>>
     where
@@ -58,12 +50,12 @@ pub trait FsProvider: Send + Sync {
     type LinkFut<'c>: Send + Future<Output = Result<LinkReply>>
     where
         Self: 'c;
-    fn link<'c>(&'c self, from: &'c Arc<BlockPath>, msg: Link) -> Self::LinkFut<'c>;
+    fn link<'c>(&'c self, from: &'c Arc<BlockPath>, msg: Link<'c>) -> Self::LinkFut<'c>;
 
     type UnlinkFut<'c>: Send + Future<Output = Result<()>>
     where
         Self: 'c;
-    fn unlink<'c>(&'c self, from: &'c Arc<BlockPath>, msg: Unlink) -> Self::UnlinkFut<'c>;
+    fn unlink<'c>(&'c self, from: &'c Arc<BlockPath>, msg: Unlink<'c>) -> Self::UnlinkFut<'c>;
 
     type ReadMetaFut<'c>: Send + Future<Output = Result<ReadMetaReply>>
     where
@@ -136,27 +128,15 @@ impl<P: FsProvider> FsProvider for &P {
         (*self).open(from, msg)
     }
 
-    fn read<'c, R, F>(&'c self, from: &'c Arc<BlockPath>, msg: ReadMsg, callback: F) -> Result<R>
-    where
-        F: 'c + FnOnce(&[u8]) -> R,
-    {
-        (*self).read(from, msg, callback)
+    type ReadGuard = P::ReadGuard;
+    type ReadFut<'c> = P::ReadFut<'c> where Self: 'c;
+    fn read<'c>(&'c self, from: &'c Arc<BlockPath>, msg: ReadMsg) -> Self::ReadFut<'c> {
+        (*self).read(from, msg)
     }
 
-    type WriteFut<'c> = P::WriteFut<'c> where Self: 'c;
-    fn write<'c, R>(
-        &'c self,
-        from: &'c Arc<BlockPath>,
-        inode: Inode,
-        handle: Handle,
-        offset: u64,
-        size: u64,
-        reader: R,
-    ) -> Self::WriteFut<'c>
-    where
-        R: 'c + Read,
-    {
-        (*self).write(from, inode, handle, offset, size, reader)
+    type WriteFut<'r> = P::WriteFut<'r> where Self: 'r;
+    fn write<'c>(&'c self, from: &'c Arc<BlockPath>, write: Write<&'c [u8]>) -> Self::WriteFut<'c> {
+        (*self).write(from, write)
     }
 
     type FlushFut<'c> = P::FlushFut<'c> where Self: 'c;
@@ -170,12 +150,12 @@ impl<P: FsProvider> FsProvider for &P {
     }
 
     type LinkFut<'c> = P::LinkFut<'c> where Self: 'c;
-    fn link<'c>(&'c self, from: &'c Arc<BlockPath>, msg: Link) -> Self::LinkFut<'c> {
+    fn link<'c>(&'c self, from: &'c Arc<BlockPath>, msg: Link<'c>) -> Self::LinkFut<'c> {
         (*self).link(from, msg)
     }
 
     type UnlinkFut<'c> = P::UnlinkFut<'c> where Self: 'c;
-    fn unlink<'c>(&'c self, from: &'c Arc<BlockPath>, msg: Unlink) -> Self::UnlinkFut<'c> {
+    fn unlink<'c>(&'c self, from: &'c Arc<BlockPath>, msg: Unlink<'c>) -> Self::UnlinkFut<'c> {
         (*self).unlink(from, msg)
     }
 
@@ -238,22 +218,12 @@ impl<P: FsProvider> FsProvider for &P {
 }
 
 struct ServerCallback<P> {
-    pool: Arc<Pool<Vec<u8>>>,
     provider: Arc<P>,
 }
 
 impl<P> ServerCallback<P> {
-    const POOL_SZ: usize = 8;
-
     fn new(provider: Arc<P>) -> Self {
-        let pool = swimmer::builder()
-            .with_starting_size(Self::POOL_SZ)
-            .with_supplier(|| Vec::with_capacity(btlib::SECTOR_SZ_DEFAULT))
-            .build();
-        Self {
-            provider,
-            pool: Arc::new(pool),
-        }
+        Self { provider }
     }
 }
 
@@ -261,7 +231,6 @@ impl<P> Clone for ServerCallback<P> {
     fn clone(&self) -> Self {
         Self {
             provider: self.provider.clone(),
-            pool: self.pool.clone(),
         }
     }
 }
@@ -272,33 +241,20 @@ impl<P: 'static + Send + Sync + FsProvider> MsgCallback for ServerCallback<P> {
     fn call<'de>(&'de self, arg: MsgReceived<FsMsg<'de>>) -> Self::CallFut<'de> {
         async move {
             let (from, body, replier) = arg.into_parts();
-            let provider = &self.provider;
+            let provider = self.provider.as_ref();
             let reply = match body {
                 FsMsg::Lookup(lookup) => FsReply::Lookup(provider.lookup(&from, lookup).await?),
                 FsMsg::Create(create) => FsReply::Create(provider.create(&from, create).await?),
                 FsMsg::Open(open) => FsReply::Open(provider.open(&from, open).await?),
                 FsMsg::Read(read) => {
-                    let buf = provider.read(&from, read, move |data| {
-                        let mut buf = self.pool.get();
-                        buf.extend_from_slice(data);
-                        buf
-                    })?;
+                    let guard = provider.read(&from, read).await?;
                     let mut replier = replier.unwrap();
                     replier
-                        .reply(FsReply::Read(ReadReply { data: &buf }))
+                        .reply(FsReply::Read(ReadReply { data: &guard }))
                         .await?;
                     return Ok(());
                 }
-                FsMsg::Write(Write {
-                    inode,
-                    handle,
-                    offset,
-                    data,
-                }) => FsReply::Write(
-                    provider
-                        .write(&from, inode, handle, offset, data.len() as u64, data)
-                        .await?,
-                ),
+                FsMsg::Write(write) => FsReply::Write(provider.write(&from, write).await?),
                 FsMsg::Flush(flush) => FsReply::Ack(provider.flush(&from, flush).await?),
                 FsMsg::ReadDir(read_dir) => {
                     FsReply::ReadDir(provider.read_dir(&from, read_dir).await?)

+ 11 - 6
crates/btfsd/src/main.rs

@@ -32,19 +32,24 @@ const DEFAULT_CONFIG: ConfigRef<'static> = ConfigRef {
     block_dir: "./bt",
 };
 
-fn provider<C: 'static + Send + Sync + Creds>(block_dir: PathBuf, creds: C) -> impl FsProvider {
+async fn provider<C: 'static + Send + Sync + Creds>(
+    block_dir: PathBuf,
+    creds: C,
+) -> impl FsProvider {
     if block_dir.exists() {
         LocalFs::new_existing(block_dir, creds, ModeAuthorizer).unwrap()
     } else {
         std::fs::create_dir_all(&block_dir).unwrap();
-        LocalFs::new_empty(block_dir, 0, creds, ModeAuthorizer).unwrap()
+        LocalFs::new_empty(block_dir, 0, creds, ModeAuthorizer)
+            .await
+            .unwrap()
     }
 }
 
-fn receiver(config: Config) -> impl Receiver {
+async fn receiver(config: Config) -> impl Receiver {
     let cred_store = TpmCredStore::from_tabrmd(&config.tabrmd, config.tpm_state_path).unwrap();
     let node_creds = cred_store.node_creds().unwrap();
-    let provider = Arc::new(provider(config.block_dir, node_creds.clone()));
+    let provider = Arc::new(provider(config.block_dir, node_creds.clone()).await);
     new_fs_server(config.ip_addr, Arc::new(node_creds), provider).unwrap()
 }
 
@@ -64,7 +69,7 @@ async fn main() {
         .with_tpm_state_path(tpm_state_path)
         .with_block_dir(block_dir)
         .build();
-    let receiver = receiver(config);
+    let receiver = receiver(config).await;
     receiver.complete().unwrap().await.unwrap();
 }
 
@@ -115,7 +120,7 @@ mod tests {
             tpm_state_path: harness.swtpm().state_path().to_owned(),
             block_dir: dir.path().join(BT_DIR),
         };
-        let rx = receiver(config);
+        let rx = receiver(config).await;
         let tx = rx.transmitter(rx.addr().clone()).await.unwrap();
         let client = FsClient::new(tx);
         TestCase {

+ 0 - 1
crates/btfuse/Cargo.toml

@@ -16,7 +16,6 @@ log = "0.4.17"
 env_logger = "0.9.0"
 anyhow = { version = "1.0.66", features = ["std", "backtrace"] }
 libc = { version = "0.2.137" }
-bytes = { version = "1.3.0" }
 serde_json = "1.0.92"
 futures = "0.3.25"
 

+ 7 - 2
crates/btfuse/src/fuse_daemon.rs

@@ -2,7 +2,7 @@
 use crate::{fuse_fs::FuseFs, PathExt, DEFAULT_CONFIG};
 
 use btfproto::server::FsProvider;
-use btlib::{BlockPath, Result};
+use btlib::{bterr, BlockPath, Result};
 use fuse_backend_rs::{
     api::server::Server,
     transport::{self, FuseChannel, FuseSession},
@@ -17,7 +17,7 @@ use std::{
     path::Path,
     sync::Arc,
 };
-use tokio::task::JoinSet;
+use tokio::{sync::oneshot, task::JoinSet};
 
 pub use private::FuseDaemon;
 
@@ -51,10 +51,15 @@ mod private {
             mnt_path: PathBuf,
             num_threads: usize,
             fallback_path: Arc<BlockPath>,
+            mounted_signal: Option<oneshot::Sender<()>>,
             provider: P,
         ) -> Result<Self> {
             let server = Arc::new(Server::new(FuseFs::new(provider, fallback_path)));
             let session = Self::session(mnt_path)?;
+            if let Some(tx) = mounted_signal {
+                tx.send(())
+                    .map_err(|_| bterr!("failed to send mounted signal"))?;
+            }
             let mut set = JoinSet::new();
             for _ in 0..num_threads {
                 let server = server.clone();

+ 25 - 17
crates/btfuse/src/fuse_fs.rs

@@ -5,7 +5,6 @@ use btlib::{
     BlockPath, Epoch, Result,
 };
 use btserde::read_from;
-use bytes::{Buf, BytesMut};
 use core::{ffi::CStr, future::Future, time::Duration};
 use fuse_backend_rs::{
     abi::fuse_abi::{stat64, Attr, CreateIn},
@@ -135,16 +134,15 @@ mod private {
                 flags: FlagValue::ReadOnly.into(),
             };
             let OpenReply { handle, .. } = provider.open(from, msg).await?;
-            let mut buf = BytesMut::new();
             let msg = Read {
                 inode: parent,
                 handle,
                 offset: 0,
                 size: u64::MAX,
             };
-            provider.read(from, msg, |data| buf.extend_from_slice(data))?;
-            let mut reader = buf.reader();
-            read_from(&mut reader).map_err(|err| err.into())
+            let guard = provider.read(from, msg).await?;
+            let mut slice: &[u8] = &guard;
+            read_from(&mut slice).map_err(|err| err.into())
         }
     }
 
@@ -279,14 +277,17 @@ mod private {
             _lock_owner: Option<u64>,
             _flags: u32,
         ) -> IoResult<usize> {
-            let path = self.path_from_luid(ctx.uid);
-            let msg = Read {
-                inode,
-                handle,
-                offset,
-                size: size as u64,
-            };
-            self.provider.read(path, msg, |data| w.write(data))?
+            block_on(async move {
+                let path = self.path_from_luid(ctx.uid);
+                let msg = Read {
+                    inode,
+                    handle,
+                    offset,
+                    size: size as u64,
+                };
+                let guard = self.provider.read(path, msg).await?;
+                w.write(&guard)
+            })
         }
 
         fn write(
@@ -305,10 +306,17 @@ mod private {
             block_on(async move {
                 let path = self.path_from_luid(ctx.uid);
                 let size: usize = size.try_into().display_err()?;
-                let WriteReply { written, .. } = self
-                    .provider
-                    .write(path, inode, handle, offset, size as u64, r)
-                    .await?;
+                // TODO: Eliminate this copying, or at least use a pool of buffers to avoid
+                // allocating on every write. We could pass `r` to the provider if it were Send.
+                let mut buf = Vec::with_capacity(size);
+                r.read_to_end(&mut buf)?;
+                let msg = Write {
+                    inode,
+                    handle,
+                    offset,
+                    data: buf.as_slice(),
+                };
+                let WriteReply { written, .. } = self.provider.write(path, msg).await?;
                 Ok(written.try_into().display_err()?)
             })
         }

+ 42 - 18
crates/btfuse/src/main.rs

@@ -15,13 +15,13 @@ use btlib::{
     },
     Result,
 };
-use core::future::Future;
 use std::{
     fs::{self},
     io,
     path::{Path, PathBuf},
     sync::Arc,
 };
+use tokio::sync::oneshot;
 
 const ENVVARS: EnvVars = EnvVars {
     tabrmd: "BT_TABRMD",
@@ -58,20 +58,20 @@ fn node_creds(state_file: PathBuf, tabrmd_cfg: &str) -> Result<TpmCreds> {
     cred_store.node_creds()
 }
 
-fn provider<C: 'static + Creds + Send + Sync>(
+async fn provider<C: 'static + Creds + Send + Sync>(
     btdir: PathBuf,
     node_creds: C,
 ) -> Result<impl FsProvider> {
     btdir.try_create_dir()?;
     let empty = fs::read_dir(&btdir)?.next().is_none();
     if empty {
-        LocalFs::new_empty(btdir, 0, node_creds, btfproto::local_fs::ModeAuthorizer {})
+        LocalFs::new_empty(btdir, 0, node_creds, btfproto::local_fs::ModeAuthorizer {}).await
     } else {
         LocalFs::new_existing(btdir, node_creds, btfproto::local_fs::ModeAuthorizer {})
     }
 }
 
-fn run_daemon(config: Config) -> impl Send + Sync + Future<Output = ()> {
+async fn run_daemon(config: Config, mounted_signal: Option<oneshot::Sender<()>>) {
     let node_creds =
         node_creds(config.tpm_state_file, &config.tabrmd).expect("failed to get node creds");
     let fallback_path = {
@@ -81,11 +81,19 @@ fn run_daemon(config: Config) -> impl Send + Sync + Future<Output = ()> {
             .unwrap();
         Arc::new(writecap.bind_path())
     };
-    let provider = provider(config.block_dir, node_creds).expect("failed to create FS provider");
-
-    let mut daemon = FuseDaemon::new(config.mnt_dir, config.threads, fallback_path, provider)
-        .expect("failed to create FUSE daemon");
-    async move { daemon.finished().await }
+    let provider = provider(config.block_dir, node_creds)
+        .await
+        .expect("failed to create FS provider");
+
+    let mut daemon = FuseDaemon::new(
+        config.mnt_dir,
+        config.threads,
+        fallback_path,
+        mounted_signal,
+        provider,
+    )
+    .expect("failed to create FUSE daemon");
+    daemon.finished().await
 }
 
 #[tokio::main]
@@ -98,7 +106,7 @@ async fn main() {
         .with_tabrmd(from_envvar(ENVVARS.tabrmd).unwrap())
         .with_mnt_options(from_envvar(ENVVARS.mnt_options).unwrap());
     let config = builder.build();
-    run_daemon(config).await;
+    run_daemon(config, None).await;
 }
 
 #[cfg(test)]
@@ -114,12 +122,12 @@ mod test {
             set_permissions, write, Permissions, ReadDir,
         },
         os::unix::fs::PermissionsExt,
-        sync::mpsc,
         thread::JoinHandle,
         time::Duration,
     };
     use swtpm_harness::SwtpmHarness;
     use tempdir::TempDir;
+    use tokio::sync::oneshot::error::TryRecvError;
 
     /// An optional timeout to wait for the FUSE daemon to start in tests.
     const TIMEOUT: Option<Duration> = Some(Duration::from_millis(1000));
@@ -172,7 +180,7 @@ mod test {
     impl TestCase {
         fn new() -> TestCase {
             let tmp = TempDir::new("btfuse").unwrap();
-            let (mounted_tx, mounted_rx) = mpsc::channel();
+            let (mounted_tx, mut mounted_rx) = oneshot::channel();
             let (swtpm, cred_store) = Self::swtpm();
             let config = Config::builder()
                 .with_block_dir(Some(tmp.path().join("bt")))
@@ -183,8 +191,26 @@ mod test {
             let config_clone = config.clone();
             let handle = std::thread::spawn(move || Self::run(mounted_tx, config_clone));
             match TIMEOUT {
-                Some(duration) => mounted_rx.recv_timeout(duration).unwrap(),
-                None => mounted_rx.recv().unwrap(),
+                Some(duration) => {
+                    let deadline = std::time::Instant::now() + duration;
+                    loop {
+                        if std::time::Instant::now() > deadline {
+                            panic!("timed out waiting for the mounted signal");
+                        }
+                        match mounted_rx.try_recv() {
+                            Ok(_) => break,
+                            Err(err) => match err {
+                                TryRecvError::Empty => {
+                                    std::thread::sleep(Duration::from_millis(10));
+                                }
+                                TryRecvError::Closed => {
+                                    panic!("channel was closed before mounted signal was sent")
+                                }
+                            },
+                        }
+                    }
+                }
+                None => mounted_rx.blocking_recv().unwrap(),
             };
             let node_principal =
                 OsString::from(cred_store.node_creds().unwrap().principal().to_string());
@@ -198,17 +224,15 @@ mod test {
             }
         }
 
-        fn run(mounted_tx: mpsc::Sender<()>, config: Config) {
+        fn run(mounted_tx: oneshot::Sender<()>, config: Config) {
             let runtime = tokio::runtime::Builder::new_current_thread()
                 .build()
                 .unwrap();
             // run_daemon can only be called in the context of a runtime, hence the need to call
             // spawn_blocking.
-            let started = runtime.spawn_blocking(|| run_daemon(config));
+            let started = runtime.spawn_blocking(move || run_daemon(config, Some(mounted_tx)));
             let future = runtime.block_on(started).unwrap();
 
-            // The file system is mounted before run_daemon returns.
-            mounted_tx.send(()).unwrap();
             runtime.block_on(future);
         }
 

+ 63 - 0
crates/btlib/src/drop_trigger.rs

@@ -0,0 +1,63 @@
+pub struct DropTrigger<F: FnOnce()> {
+    trigger: Option<F>,
+}
+
+impl<F: FnOnce()> DropTrigger<F> {
+    pub fn new(trigger: F) -> Self {
+        Self {
+            trigger: Some(trigger),
+        }
+    }
+
+    pub fn disarm(&mut self) -> bool {
+        self.trigger.take().is_some()
+    }
+}
+
+impl<F: FnOnce()> Drop for DropTrigger<F> {
+    fn drop(&mut self) {
+        if let Some(trigger) = self.trigger.take() {
+            trigger()
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn trigger_called_on_drop() {
+        let mut x = 0;
+        let dt = DropTrigger::new(|| x += 1);
+
+        drop(dt);
+
+        assert_eq!(1, x);
+    }
+
+    #[test]
+    fn trigger_not_called_when_disarmed() {
+        let mut x = 0;
+        let mut dt = DropTrigger::new(|| x += 1);
+
+        let was_armed = dt.disarm();
+        drop(dt);
+
+        assert!(was_armed);
+        assert_eq!(0, x);
+    }
+
+    #[test]
+    fn disarm_returns_false_if_was_not_armed() {
+        let mut x = 0;
+        let mut dt = DropTrigger::new(|| x += 1);
+
+        let first = dt.disarm();
+        let second = dt.disarm();
+        drop(dt);
+
+        assert!(first);
+        assert!(!second);
+    }
+}

+ 1 - 0
crates/btlib/src/lib.rs

@@ -6,6 +6,7 @@ pub mod collections;
 pub mod config_helpers;
 /// Code which enables cryptographic operations.
 pub mod crypto;
+pub mod drop_trigger;
 pub mod error;
 pub mod log;
 pub mod sectored_buf;

Some files were not shown because too many files changed in this diff