Explorar o código

Added a trait to allow MerkleStream to pass its integrity information
to its inner stream.

Matthew Carr %!s(int64=2) %!d(string=hai) anos
pai
achega
7ee175b9ef
Modificáronse 3 ficheiros con 210 adicións e 59 borrados
  1. 25 15
      crates/btlib/src/crypto/mod.rs
  2. 158 29
      crates/btlib/src/lib.rs
  3. 27 15
      crates/btlib/src/test_helpers.rs

+ 25 - 15
crates/btlib/src/crypto/mod.rs

@@ -1,7 +1,7 @@
 /// Functions for performing cryptographic operations on the main data structures.
 mod tpm;
 
-use crate::{Decompose, SectoredBuf, SECTOR_SZ_DEFAULT};
+use crate::{BoxInIoErr, Decompose, IntegrityWrite, SectoredBuf, SECTOR_SZ_DEFAULT};
 
 use super::{
     fmt, io, BigArray, Block, Deserialize, Display, Epoch, Formatter, Hashable, Header, Owned,
@@ -1211,6 +1211,9 @@ pub trait MerkleNode: Default + Serialize + for<'de> Deserialize<'de> {
         right: Option<&'a Self>,
     ) -> Result<()>;
 
+    /// Attempts to borrow the data in this node as a slice.
+    fn try_as_slice(&self) -> Result<&[u8]>;
+
     /// Computes the hash of the data produced by the given iterator and writes it to the
     /// given slice.
     fn digest<'a, I: Iterator<Item = &'a [u8]>>(dest: &mut [u8], parts: I) -> Result<()> {
@@ -1329,6 +1332,13 @@ impl MerkleNode for Sha2_256 {
             Err(Error::HashCmpFailure)
         }
     }
+
+    fn try_as_slice(&self) -> Result<&[u8]> {
+        self.0
+            .as_ref()
+            .map(|arr| arr.as_slice())
+            .ok_or_else(|| Error::custom("this merkle node is empty"))
+    }
 }
 
 /// An index into a binary tree. This type provides convenience methods for navigating a tree.
@@ -1623,17 +1633,20 @@ impl<T, H> Decompose<T> for MerkleStream<T, H> {
     }
 }
 
-impl<T: Write, H: MerkleNode> Write for MerkleStream<T, H> {
+impl<T: IntegrityWrite, H: MerkleNode> Write for MerkleStream<T, H> {
     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
         self.assert_sector_sz(buf.len())?;
         self.tree.write(self.offset, buf)?;
-        let written = self.inner.write(buf)?;
+        // Safety: We know the root node exists and is non-empty because we just wrote data into
+        // the tree.
+        let root = self.tree.nodes.first().unwrap();
+        let written = self.inner.integrity_write(buf, root.try_as_slice()?)?;
         self.offset += self.sector_sz();
         Ok(written)
     }
 
     fn flush(&mut self) -> io::Result<()> {
-        self.inner.flush()
+        Ok(())
     }
 }
 
@@ -1650,7 +1663,7 @@ impl<T: Read, H: MerkleNode> Read for MerkleStream<T, H> {
 impl<T: Seek, H> Seek for MerkleStream<T, H> {
     fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
         let from_start = self.inner.seek(pos)?;
-        self.offset = from_start as usize;
+        self.offset = from_start.try_into().box_err()?;
         Ok(from_start)
     }
 }
@@ -1794,7 +1807,7 @@ impl<T: Seek> Seek for SecretStream<T> {
                 if offset >= 0 {
                     outer_offset + offset as u64
                 } else {
-                    outer_offset - (-offset) as u64
+                    outer_offset - (-offset as u64)
                 }
             }
             SeekFrom::End(_) => {
@@ -1812,7 +1825,7 @@ impl<T: Seek> Seek for SecretStream<T> {
     }
 }
 
-pub(crate) fn unveil<T: Read + Write + Seek, C: Encrypter + Decrypter, D>(
+pub(crate) fn unveil<T: Read + IntegrityWrite + Seek, C: Encrypter + Decrypter, D>(
     block: Block<T, D>,
     principal: &Principal,
     creds: &C,
@@ -1965,12 +1978,9 @@ pub fn verify_header(_header: &Header, _sig: &Signature) -> Result<()> {
 mod tests {
     use super::*;
     use crate::{test_helpers::*, BrotliParams, SectoredBuf, SECTOR_SZ_DEFAULT};
-    use std::{
-        io::{Cursor, SeekFrom},
-        time::Duration,
-    };
+    use std::{io::SeekFrom, time::Duration};
 
-    fn encrypt_decrypt_block_test_case<T: Read + Write + Seek, C: Creds, D>(
+    fn encrypt_decrypt_block_test_case<T: Read + IntegrityWrite + Seek, C: Creds, D>(
         mut block: Block<T, D>,
         principal: &Principal,
         creds: &C,
@@ -2411,7 +2421,7 @@ mod tests {
 
     fn merkle_stream_sequential_test_case(sect_sz: usize, sect_count: usize) {
         let mut stream = MerkleStream::new(MerkleTree::<Sha2_256>::empty(sect_sz))
-            .try_compose(Cursor::new(vec![0u8; sect_count * sect_sz]))
+            .try_compose(BtCursor::new(vec![0u8; sect_count * sect_sz]))
             .expect("compose failed");
         for k in 1..(sect_count + 1) {
             let sector = vec![k as u8; sect_sz];
@@ -2437,7 +2447,7 @@ mod tests {
 
     fn merkle_stream_random_test_case(rando: Randomizer, sect_sz: usize, sect_ct: usize) {
         let mut stream = MerkleStream::new(MerkleTree::<Sha2_256>::empty(sect_sz))
-            .try_compose(Cursor::new(vec![0u8; sect_sz * sect_ct]))
+            .try_compose(BtCursor::new(vec![0u8; sect_sz * sect_ct]))
             .expect("compose failed");
         let indices: Vec<usize> = rando.take(sect_ct).map(|e| e % sect_ct).collect();
         for index in indices.iter().map(|e| *e) {
@@ -2476,7 +2486,7 @@ mod tests {
     fn compose_merkle_and_secret_streams() {
         const SECT_SZ: usize = 4096;
         const SECT_CT: usize = 16;
-        let memory = Cursor::new([0u8; SECT_SZ * SECT_CT]);
+        let memory = BtCursor::new([0u8; SECT_SZ * SECT_CT]);
         let merkle = MerkleStream::new(MerkleTree::<Sha2_256>::empty(SECT_SZ))
             .try_compose(memory)
             .expect("compose for merkle failed");

+ 158 - 29
crates/btlib/src/lib.rs

@@ -18,7 +18,7 @@ extern crate lazy_static;
 use brotli::{CompressorWriter, Decompressor};
 use btserde::{self, read_from, write_to};
 mod crypto;
-use crypto::{AsymKeyPub, Ciphertext, Hash, HashKind, Sign, Signature, SymKey};
+use crypto::{AsymKeyPub, Ciphertext, Hash, HashKind, Sign, Signature, Signer, SymKey};
 
 use log::error;
 use serde::{Deserialize, Serialize};
@@ -31,7 +31,7 @@ use std::{
     hash::Hash as Hashable,
     io::{self, Read, Seek, SeekFrom, Write},
     ops::{Add, Sub},
-    sync::{Arc, RwLock},
+    sync::{Arc, PoisonError, RwLock},
     time::{Duration, SystemTime},
 };
 
@@ -96,8 +96,36 @@ impl From<std::num::TryFromIntError> for Error {
     }
 }
 
+impl<T: std::fmt::Debug> From<PoisonError<T>> for Error {
+    fn from(err: PoisonError<T>) -> Self {
+        Error::custom(err.to_string())
+    }
+}
+
 type Result<T> = std::result::Result<T, Error>;
 
+/// TODO: Remove this once the error_chain crate is integrated.
+trait BoxInIoErr<T> {
+    fn box_err(self) -> std::result::Result<T, io::Error>;
+}
+
+impl<T, E: std::error::Error + Send + Sync + 'static> BoxInIoErr<T> for std::result::Result<T, E> {
+    fn box_err(self) -> std::result::Result<T, io::Error> {
+        self.map_err(|err| io::Error::new(io::ErrorKind::Other, err))
+    }
+}
+
+/// TODO: Remove this once the error_chain crate is integrated.
+trait StrInIoErr<T> {
+    fn str_err(self) -> std::result::Result<T, io::Error>;
+}
+
+impl<T, E: Display> StrInIoErr<T> for std::result::Result<T, E> {
+    fn str_err(self) -> std::result::Result<T, io::Error> {
+        self.map_err(|err| io::Error::new(io::ErrorKind::Other, err.to_string()))
+    }
+}
+
 /// A Block tagged with its version number. When a block of a previous version is received over
 /// the network or read from the filesystem, it is upgraded to the current version before being
 /// processed.
@@ -123,26 +151,114 @@ trait Sectored {
     }
 }
 
+/// A version of the `Write` trait, which allows integrity information to be supplied when writing.
+trait IntegrityWrite {
+    fn integrity_write(&mut self, buf: &[u8], integrity: &[u8]) -> io::Result<usize>;
+}
+
 #[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
 pub struct Header {
     path: Path,
     readcaps: HashMap<Principal, Ciphertext<SymKey>>,
     writecap: Writecap,
-    merkle_root: Hash,
+    /// A hash which provides integrity for the contents of the block body.
+    integrity: Hash,
 }
 
 #[derive(Debug, PartialEq, Serialize, Deserialize)]
 struct BlockShared<C> {
     header: Header,
     sig: Signature,
+    #[serde(skip)]
     creds: C,
 }
 
 struct BlockStream<T, C> {
     shared: Arc<RwLock<BlockShared<C>>>,
+    body_len: u64,
+    header_buf: Vec<u8>,
     inner: T,
 }
 
+impl<T, C> BlockStream<T, C> {
+    fn new(shared: BlockShared<C>, inner: T, body_len: u64) -> BlockStream<T, C> {
+        BlockStream {
+            shared: Arc::new(RwLock::new(shared)),
+            inner,
+            header_buf: Vec::new(),
+            body_len,
+        }
+    }
+}
+
+impl<T: Seek + Write, C: std::fmt::Debug + Signer> BlockStream<T, C> {
+    fn write_trailer(&mut self, integrity: &[u8]) -> Result<()> {
+        let pos = self.inner.stream_position()?;
+        self.body_len = self.body_len.max(pos);
+        self.inner.seek(SeekFrom::Start(self.body_len))?;
+        {
+            let mut shared = self.shared.write()?;
+            shared.header.integrity.as_mut().copy_from_slice(integrity);
+            self.header_buf.clear();
+            write_to(&shared.header, &mut self.header_buf)?;
+            shared.sig = shared
+                .creds
+                .sign(std::iter::once(self.header_buf.as_slice()))?;
+
+            self.inner.write_all(&self.header_buf)?;
+            write_to(&shared.sig, &mut self.inner)?;
+        }
+        let end: i64 = (self.inner.stream_position()? + 8).try_into()?;
+        let body_len: i64 = self.body_len.try_into()?;
+        let offset = end - body_len;
+        write_to(&offset, &mut self.inner)?;
+        self.inner.seek(SeekFrom::Start(pos))?;
+        Ok(())
+    }
+}
+
+impl<T: Write + Seek, C: std::fmt::Debug + Signer> IntegrityWrite for BlockStream<T, C> {
+    fn integrity_write(&mut self, buf: &[u8], integrity: &[u8]) -> io::Result<usize> {
+        let written = self.inner.write(buf)?;
+        if written > 0 {
+            let result = self.write_trailer(integrity);
+            if let Err(err) = result {
+                error!("error occurred while writing block trailer: {}", err);
+            }
+        }
+        Ok(written)
+    }
+}
+
+impl<T: Read, C> Read for BlockStream<T, C> {
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        self.inner.read(buf)
+    }
+}
+
+/// Adds a signed integer to an unsigned integer and returns the result.
+fn add_signed(unsigned: u64, signed: i64) -> u64 {
+    if signed >= 0 {
+        unsigned + signed as u64
+    } else {
+        unsigned - (-signed as u64)
+    }
+}
+
+impl<T: Seek, C> Seek for BlockStream<T, C> {
+    /// Seeks to the given position in the stream. If a position beyond the end of the stream is
+    /// specified, the the seek will be to the end of the stream.
+    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
+        let from_start = match pos {
+            SeekFrom::Start(from_start) => from_start,
+            SeekFrom::Current(from_curr) => add_signed(self.inner.stream_position()?, from_curr),
+            SeekFrom::End(from_end) => add_signed(self.body_len, from_end),
+        };
+        self.inner
+            .seek(SeekFrom::Start(from_start.min(self.body_len)))
+    }
+}
+
 /// A container which binds together ciphertext along with the metadata needed to identify,
 /// verify and decrypt it.
 #[derive(Debug)]
@@ -170,16 +286,36 @@ impl<T, C> Block<T, C> {
     }
 }
 
-impl<C> Block<File, C> {
-    fn new<P: AsRef<std::path::Path>>(creds: C, path: P) -> Result<Block<File, C>> {
-        let mut file = OpenOptions::new().read(true).write(true).open(path)?;
-        let header: Header = read_from(&mut file)?;
-        let sig: Signature = read_from(&mut file)?;
+impl<T: Read + Seek, C> Block<T, C> {
+    fn with_body(body: BlockStream<T, C>) -> Block<BlockStream<T, C>, C> {
+        Block {
+            shared: body.shared.clone(),
+            body,
+        }
+    }
+
+    fn new(mut inner: T, creds: C) -> Result<Block<BlockStream<T, C>, C>> {
+        // TODO: What if the inner stream is empty?
+        inner.seek(SeekFrom::End(-8))?;
+        let offset: i64 = read_from(&mut inner)?;
+        let body_len = inner.seek(SeekFrom::Current(offset))?;
+        let header: Header = read_from(&mut inner)?;
+        let sig: Signature = read_from(&mut inner)?;
         crypto::verify_header(&header, &sig)?;
-        Ok(Block {
-            shared: Arc::new(RwLock::new(BlockShared { header, sig, creds })),
-            body: file,
-        })
+        inner.seek(SeekFrom::Start(0))?;
+        let shared = BlockShared { header, sig, creds };
+        let body = BlockStream::new(shared, inner, body_len);
+        Ok(Block::with_body(body))
+    }
+}
+
+impl<C> Block<File, C> {
+    fn from_path<P: AsRef<std::path::Path>>(
+        creds: C,
+        path: P,
+    ) -> Result<Block<BlockStream<File, C>, C>> {
+        let inner = OpenOptions::new().read(true).write(true).open(path)?;
+        Block::new(inner, creds)
     }
 }
 
@@ -307,13 +443,6 @@ impl<T: Read> TryCompose<T, Decompressor<T>> for BrotliParams {
     }
 }
 
-/// TODO: Remove this once the error_chain crate is integrated.
-fn err_conv<T, E: std::error::Error + Send + Sync + 'static>(
-    result: std::result::Result<T, E>,
-) -> std::result::Result<T, io::Error> {
-    result.map_err(|err| io::Error::new(io::ErrorKind::Other, err))
-}
-
 /// A stream which buffers writes and read such that the inner stream only sees reads and writes
 /// of sector length buffers.
 struct SectoredBuf<T> {
@@ -385,7 +514,7 @@ impl<T: Read + Seek> SectoredBuf<T> {
     /// Fills the internal buffer by reading from the inner stream at the current position
     /// and updates `self.buf_start` with the position read from.
     fn fill_internal_buf(&mut self) -> io::Result<usize> {
-        self.buf_start = err_conv(self.inner.stream_position()?.try_into())?;
+        self.buf_start = self.inner.stream_position()?.try_into().box_err()?;
         let read_bytes = if self.buf_start < self.len {
             let read_bytes = self.inner.fill_buf(&mut self.buf)?;
             if read_bytes < self.buf.len() {
@@ -489,9 +618,9 @@ impl<T: Seek + Read + Write> Write for SectoredBuf<T> {
         }
 
         // Write out the contents of the buffer.
-        let sect_sz: u64 = err_conv(self.sector_sz().try_into())?;
+        let sect_sz: u64 = self.sector_sz().try_into().box_err()?;
         let inner_pos = self.inner.stream_position()?;
-        let inner_pos_usize: usize = err_conv(inner_pos.try_into())?;
+        let inner_pos_usize: usize = inner_pos.try_into().box_err()?;
         let is_new_sector = self.pos > inner_pos_usize;
         let is_full = (self.buf.len() - self.buf_pos()) == 0;
         let seek_to = if is_new_sector {
@@ -503,7 +632,7 @@ impl<T: Seek + Read + Write> Write for SectoredBuf<T> {
         } else {
             // The contents of the buffer were previously read from inner, so we write the
             // updated contents to the same offset.
-            let sect_start: u64 = err_conv(self.buf_start.try_into())?;
+            let sect_start: u64 = self.buf_start.try_into().box_err()?;
             self.inner.seek(SeekFrom::Start(sect_start))?;
             if is_full {
                 inner_pos
@@ -517,8 +646,8 @@ impl<T: Seek + Read + Write> Write for SectoredBuf<T> {
         self.len = self.len.max(self.pos);
         self.inner.seek(SeekFrom::Start(0))?;
         self.fill_internal_buf()?;
-        let len: u64 = err_conv(self.len.try_into())?;
-        err_conv(write_to(&len, &mut self.buf.as_mut_slice()))?;
+        let len: u64 = self.len.try_into().box_err()?;
+        write_to(&len, &mut self.buf.as_mut_slice()).box_err()?;
         self.inner.seek(SeekFrom::Start(0))?;
         self.inner.write_all(&self.buf)?;
 
@@ -591,15 +720,15 @@ impl<T: Seek + Read + Write> Seek for SectoredBuf<T> {
         };
         let sect_sz = self.sector_sz();
         let sect_index = self.buf_sector_index();
-        let sect_index_new = err_conv(TryInto::<usize>::try_into(inner_pos_new))? / sect_sz;
-        let pos: u64 = err_conv(self.pos.try_into())?;
+        let sect_index_new = TryInto::<usize>::try_into(inner_pos_new).box_err()? / sect_sz;
+        let pos: u64 = self.pos.try_into().box_err()?;
         if sect_index != sect_index_new || pos == inner_pos {
             self.flush()?;
-            let seek_to: u64 = err_conv((sect_index_new * sect_sz).try_into())?;
+            let seek_to: u64 = (sect_index_new * sect_sz).try_into().box_err()?;
             self.inner.seek(SeekFrom::Start(seek_to))?;
             self.fill_internal_buf()?;
         }
-        self.pos = err_conv(inner_pos_new.try_into())?;
+        self.pos = inner_pos_new.try_into().box_err()?;
         Ok(Self::self_pos(inner_pos_new))
     }
 }

+ 27 - 15
crates/btlib/src/test_helpers.rs

@@ -171,11 +171,11 @@ pub(crate) fn make_readcap_for<C: Encrypter + Owned>(creds: &C) -> Readcap {
     }
 }
 
-pub(crate) fn make_block() -> Block<SerdeCursor<Vec<u8>>, impl Creds> {
+pub(crate) fn make_block() -> Block<BtCursor<Vec<u8>>, impl Creds> {
     make_block_with(make_readcap())
 }
 
-pub(crate) fn make_block_with(readcap: Readcap) -> Block<SerdeCursor<Vec<u8>>, impl Creds> {
+pub(crate) fn make_block_with(readcap: Readcap) -> Block<BtCursor<Vec<u8>>, impl Creds> {
     let mut readcaps = HashMap::new();
     readcaps.insert(readcap.issued_to, readcap.key);
     // Notice that the writecap path contains the block path. If this were not the case, the block
@@ -188,12 +188,12 @@ pub(crate) fn make_block_with(readcap: Readcap) -> Block<SerdeCursor<Vec<u8>>, i
                 path: make_path_with_owner(root_writecap.issued_to.clone(), vec!["apps", "verse"]),
                 readcaps,
                 writecap,
-                merkle_root: Hash::Sha2_256([0u8; HashKind::Sha2_256.len()]),
+                integrity: Hash::Sha2_256([0u8; HashKind::Sha2_256.len()]),
             },
             sig: Signature::copy_from(Sign::RSA_PSS_3072_SHA_256, &SIGNATURE),
             creds,
         })),
-        body: SerdeCursor::new(Vec::new()),
+        body: BtCursor::new(Vec::new()),
     }
 }
 
@@ -367,22 +367,22 @@ mod serde_cursor {
     }
 }
 
-/// A wrapper for `Cursor<T>` which implements `Serialize` and `Deserialize<'de>` for any `'de'.
+/// A wrapper for `Cursor<T>` which implements additional traits.
 #[derive(Debug, PartialEq, Serialize, Deserialize)]
-pub struct SerdeCursor<T: FromVec> {
+pub struct BtCursor<T: FromVec> {
     #[serde(with = "serde_cursor")]
     cursor: RefCell<Cursor<T>>,
 }
 
-impl<T: FromVec> SerdeCursor<T> {
-    fn new(inner: T) -> SerdeCursor<T> {
-        SerdeCursor {
+impl<T: FromVec> BtCursor<T> {
+    pub(crate) fn new(inner: T) -> BtCursor<T> {
+        BtCursor {
             cursor: RefCell::new(Cursor::new(inner)),
         }
     }
 }
 
-impl Write for SerdeCursor<Vec<u8>> {
+impl Write for BtCursor<Vec<u8>> {
     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
         self.cursor.get_mut().write(buf)
     }
@@ -392,7 +392,7 @@ impl Write for SerdeCursor<Vec<u8>> {
     }
 }
 
-impl<const N: usize> Write for SerdeCursor<[u8; N]> {
+impl<const N: usize> Write for BtCursor<[u8; N]> {
     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
         self.cursor.get_mut().write(buf)
     }
@@ -402,28 +402,40 @@ impl<const N: usize> Write for SerdeCursor<[u8; N]> {
     }
 }
 
-impl<T: FromVec> Read for SerdeCursor<T> {
+impl<T: FromVec> Read for BtCursor<T> {
     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
         self.cursor.get_mut().read(buf)
     }
 }
 
-impl<T: FromVec> Seek for SerdeCursor<T> {
+impl<T: FromVec> Seek for BtCursor<T> {
     fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
         self.cursor.get_mut().seek(pos)
     }
 }
 
+impl IntegrityWrite for BtCursor<Vec<u8>> {
+    fn integrity_write(&mut self, buf: &[u8], _integrity: &[u8]) -> io::Result<usize> {
+        self.cursor.get_mut().write(buf)
+    }
+}
+
+impl<const N: usize> IntegrityWrite for BtCursor<[u8; N]> {
+    fn integrity_write(&mut self, buf: &[u8], integrity: &[u8]) -> io::Result<usize> {
+        self.cursor.get_mut().write(buf)
+    }
+}
+
 #[derive(Debug, PartialEq, Serialize, Deserialize)]
 pub struct SectoredCursor<T: FromVec> {
-    cursor: SerdeCursor<T>,
+    cursor: BtCursor<T>,
     sect_sz: usize,
 }
 
 impl<T: FromVec> SectoredCursor<T> {
     pub fn new(inner: T, sect_sz: usize) -> SectoredCursor<T> {
         SectoredCursor {
-            cursor: SerdeCursor::new(inner),
+            cursor: BtCursor::new(inner),
             sect_sz,
         }
     }