Ver Fonte

Implemented the `MerkleStream`.

Matthew Carr há 2 anos atrás
pai
commit
1bf444ee2c
3 ficheiros alterados com 418 adições e 145 exclusões
  1. 1 0
      crates/.vscode/settings.json
  2. 347 144
      crates/btnode/src/crypto/mod.rs
  3. 70 1
      crates/btnode/src/test_helpers.rs

+ 1 - 0
crates/.vscode/settings.json

@@ -11,6 +11,7 @@
         "encrypter",
         "Encryptor",
         "Hashable",
+        "Merkle",
         "newtype",
         "PKCS",
         "pkey",

+ 347 - 144
crates/btnode/src/crypto/mod.rs

@@ -24,7 +24,6 @@ use super::{
     Header,
     Compose,
     Sectored,
-    SECTOR_SZ_DEFAULT,
 };
 
 use openssl::{
@@ -71,6 +70,7 @@ pub enum Error {
     IncorrectSize { expected: usize, actual: usize },
     IndexOutOfBounds { index: usize, limit: usize },
     IndivisibleSize { divisor: usize, actual: usize },
+    InvalidOffset { actual: usize, limit: usize },
     HashCmpFailure,
     RootHashNotVerified,
     WritecapAuthzErr(WritecapAuthzErr),
@@ -108,6 +108,13 @@ impl Display for Error {
                 => write!(
                     f, "expected a size which is divisible by {} but got {}", divisor, actual
                 ),
+            Error::InvalidOffset { actual, limit }
+                => write!(
+                    f,
+                    "offset {} is out of bounds, it must be strictly less than {}",
+                    actual,
+                    limit
+            ),
             Error::HashCmpFailure => write!(f, "hash data are not equal"),
             Error::RootHashNotVerified => write!(f, "root hash is not verified"),
             Error::WritecapAuthzErr(err) => err.fmt(f),
@@ -176,7 +183,7 @@ fn rand_vec(len: usize) -> Result<Vec<u8>> {
 #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Hashable, Clone, EnumDiscriminants)]
 #[strum_discriminants(derive(EnumString, Display, Serialize, Deserialize))]
 #[strum_discriminants(name(HashKind))]
-pub(crate) enum Hash {
+pub enum Hash {
     Sha2_256([u8; HashKind::Sha2_256.len()]),
     #[serde(with = "BigArray")]
     Sha2_512([u8; HashKind::Sha2_512.len()]),
@@ -189,12 +196,25 @@ impl Default for HashKind {
 }
 
 impl HashKind {
-    const fn len(self) -> usize {
+    pub const fn len(self) -> usize {
         match self {
             HashKind::Sha2_256 => 32,
             HashKind::Sha2_512 => 64,
         }
     }
+
+    pub fn digest<'a, I: Iterator<Item = &'a [u8]>>(self, dest: &mut [u8], parts: I) -> Result<()> {
+        if dest.len() != self.len() {
+            return Err(Error::IncorrectSize { expected: self.len(), actual: dest.len() })
+        }
+        let mut hasher = Hasher::new(self.into())?;
+        for part in parts {
+            hasher.update(part)?;
+        }
+        let hash = hasher.finish()?;
+        dest.copy_from_slice(&hash);
+        Ok(())
+    }
 }
 
 impl From<HashKind> for MessageDigest {
@@ -1110,14 +1130,15 @@ pub(crate) trait CredStore {
     ) -> Result<Self::CredHandle>;
 }
 
-/// Returns the base 2 logarithm of the given number. This function will return 0 when given 0.
-fn log2(mut n: usize) -> usize {
+/// Returns the base 2 logarithm of the given number. This function will return -1 when given 0, and
+/// this is the only input for which a negative value is returned.
+fn log2(mut n: usize) -> isize {
     // Is there a better implementation of this in std? I wasn't able to find an integer log2
     // function in std, so I wrote this naive implementation.
     if 0 == n {
-        return 0;
+        return -1;
     }
-    let num_bits = 8 * std::mem::size_of::<usize>(); 
+    let num_bits = 8 * std::mem::size_of::<usize>() as isize; 
     for k in 0..num_bits {
         n >>= 1;
         if 0 == n {
@@ -1127,20 +1148,35 @@ fn log2(mut n: usize) -> usize {
     num_bits
 }
 
-/// Returns 2^x.
-fn exp2(x: usize) -> usize {
-    1 << x
+/// Returns 2^x. Note that 0 is returned for any negative input.
+fn exp2(x: isize) -> usize {
+    if x < 0 {
+        0
+    }
+    else {
+        1 << x
+    }
+}
+
+trait SectoredExt : Sectored {
+    fn assert_sector_sz(&self, actual: usize) -> Result<()> {
+        let expected = self.sector_sz();
+        if expected == actual {
+            Ok(())
+        }
+        else {
+            Err(Error::IncorrectSize { expected, actual })
+        }
+    }
 }
 
+impl<T: Sectored> SectoredExt for T {}
+
 /// Trait for types which can be used as nodes in a `MerkleTree`.
 trait MerkleNode: Default + Serialize + for<'de> Deserialize<'de> {
     /// The kind of hash algorithm that this `HashData` uses.
     const KIND: HashKind;
 
-    /// Computes the hash of the data produced by the given iterator and writes it to the
-    /// given slice.
-    fn digest<'a, I: Iterator<Item = &'a [u8]>>(dest: &mut [u8], parts: I) -> Result<()>;
-
     /// Creates a new `HashData` instance by hashing the data produced by the given iterator and
     /// storing it in self.
     fn new<'a, I: Iterator<Item = &'a [u8]>>(parts: I) -> Result<Self>;
@@ -1163,6 +1199,13 @@ trait MerkleNode: Default + Serialize + for<'de> Deserialize<'de> {
     fn assert_parent_of<'a, I: Iterator<Item = &'a [u8]>>(
         &self, prefix: I, left: Option<&'a Self>, right: Option<&'a Self>
     ) -> Result<()>;
+
+
+    /// Computes the hash of the data produced by the given iterator and writes it to the
+    /// given slice.
+    fn digest<'a, I: Iterator<Item = &'a [u8]>>(dest: &mut [u8], parts: I) -> Result<()> {
+        Self::KIND.digest(dest, parts)
+    }
 }
 
 // TODO: Once full const generic support lands we can use a HashKind as a const param. Then we won't
@@ -1201,9 +1244,9 @@ impl Sha2_256 {
             (Some(left), Some(right))
                 => Self::digest(dest, prefix.chain([left, right].into_iter())),
             (Some(left), None)
-                => Self::digest(dest, prefix.chain([left].into_iter())),
+                => Self::digest(dest, prefix.chain([left, b"None"].into_iter())),
             (None, Some(right))
-                => Self::digest(dest, prefix.chain([right].into_iter())),
+                => Self::digest(dest, prefix.chain([b"None", right].into_iter())),
             (None, None) => when_neither(),
         }
     }
@@ -1212,19 +1255,6 @@ impl Sha2_256 {
 impl MerkleNode for Sha2_256 {
     const KIND: HashKind = HashKind::Sha2_256;
 
-    fn digest<'a, I: Iterator<Item = &'a [u8]>>(dest: &mut [u8], parts: I) -> Result<()> {
-        if dest.len() != Self::KIND.len() {
-            return Err(Error::IncorrectSize { expected: Self::KIND.len(), actual: dest.len() })
-        }
-        let mut hasher = Hasher::new(Self::KIND.into())?;
-        for part in parts {
-            hasher.update(part)?;
-        }
-        let hash = hasher.finish()?;
-        dest.copy_from_slice(&hash);
-        Ok(())
-    }
-
     fn new<'a, I: Iterator<Item = &'a [u8]>>(parts: I) -> Result<Self> {
         let mut array = [0u8; Self::KIND.len()];
         Self::digest(&mut array, parts)?;
@@ -1349,37 +1379,36 @@ struct MerkleTree<T> {
 }
 
 impl<T> MerkleTree<T> {
-    /// A byte to prefix data being hashed for leaf nodes.
-    const LEAF_PREFIX: &'static [u8] = &[0x00];
-    /// A byte to prefix data being hashed for interior nodes.
-    const INTERIOR_PREFIX: &'static [u8] = &[0x01];
+    /// A byte to prefix data being hashed for leaf nodes. It's important that this is different
+    /// from `INTERIOR_PREFIX`.
+    const LEAF_PREFIX: &'static [u8] = b"Leaf";
+    /// A byte to prefix data being hashed for interior nodes. It's important that this is different
+    /// from 'LEAF_PREFIX`.
+    const INTERIOR_PREFIX: &'static [u8] = b"Interior";
     
     /// Creates a new tree with no nodes in it and the given sector size.
     fn empty(sector_sz: usize) -> MerkleTree<T> {
         MerkleTree { nodes: Vec::new(), sector_sz, root_verified: true }
     }
 
-    /// Returns the number of generations in self.
-    fn generations(&self) -> usize {
+    /// Returns the number of generations in self. This method returns -1 when the tree is empty,
+    /// and this is the only case where a negative value is returned.
+    fn generations(&self) -> isize {
         log2(self.nodes.len())
     }
 
-    fn assert_sector_sz(&self, len: usize) -> Result<()> {
-        if self.sector_sz == len {
-            Ok(())
+    /// Returns the number of nodes in a complete binary tree with the given number of
+    /// generations. Note that `generations` is 0-based, so a tree with 1 node has 0 generations,
+    /// and a tree with 3 has 1.
+    fn len(generations: isize) -> usize {
+        if generations >= 0 {
+            exp2(generations + 1) - 1
         }
         else {
-            Err(Error::IncorrectSize { expected: self.sector_sz, actual: len })
+            0
         }
     }
 
-    /// Returns capacity needed to store a complete binary tree with the given number of
-    /// generations. Note that `generations` is 0-based, so a tree with 1 node has 0 generations,
-    /// and a tree with 3 has 1.
-    fn capacity(generations: usize) -> usize {
-        exp2(generations + 1) - 1
-    }
-
     /// Returns the index of the last node in the tree.
     fn end(&self) -> BinTreeIndex {
         BinTreeIndex(self.nodes.len() - 1)
@@ -1394,8 +1423,14 @@ impl<T> MerkleTree<T> {
     }
 
     /// Returns the index which corresponds to the given offset into the protected data.
-    fn offset_to_index(&self, offset: usize) -> BinTreeIndex {
-        BinTreeIndex(exp2(self.generations()) - 1 + offset / self.sector_sz)
+    fn offset_to_index(&self, offset: usize) -> Result<BinTreeIndex> {
+        let gens = self.generations();
+        let sector_index = offset / self.sector_sz;
+        let index_limit = exp2(gens);
+        if sector_index >= index_limit {
+            return Err(Error::InvalidOffset { actual: offset, limit: index_limit * self.sector_sz })
+        }
+        Ok(BinTreeIndex(exp2(gens) - 1 + sector_index))
     }
 
     /// Returns an iterator of slices which need to be hashed along with the data to create a leaf
@@ -1427,35 +1462,48 @@ impl<T: MerkleNode> MerkleTree<T> {
 
     /// Hashes the given data, adds a new node to the tree with its hash and updates the hashes
     /// of all parent nodes.
-    fn push(&mut self, data: &[u8]) -> Result<()> {
+    fn write(&mut self, offset: usize, data: &[u8]) -> Result<()> {
         self.assert_sector_sz(data.len())?;
 
-        if self.nodes.is_empty() {
-            self.nodes.push(T::new(Self::leaf_parts(data))?);
-            return Ok(())
-        }
-
+        let sector_index = offset / self.sector_sz;
         let generations = self.generations();
-        if Self::capacity(generations) == self.nodes.len() {
+        let sector_index_sup = exp2(generations);
+        if sector_index >= sector_index_sup {
             // Need to resize the tree.
-            let new_sz = Self::capacity(generations + 1);
-            self.nodes.reserve_exact(new_sz - self.nodes.len());
-            // Extend the vector so that half of the new generation is allocated.
-            self.nodes.resize_with(self.nodes.len() + exp2(generations), T::default);
+            let generations_new = log2(sector_index) + 1;
+            let new_cap = Self::len(generations_new) - self.nodes.len();
+            self.nodes.reserve_exact(new_cap);
+            // Extend the vector so there is enough room to fit the current leaves in the last
+            // generation.
+            let leaf_ct = self.nodes.len() - Self::len(generations - 1);
+            let new_len = Self::len(generations_new - 1) + sector_index + 1;
+            self.nodes.resize_with(new_len, T::default);
             // Shift all previously allocated nodes down the tree.
-            for k in (0..(generations + 1)).rev() {
-                let shift = exp2(k);
-                let start = exp2(k) - 1;
-                for index in start..(2*start + 1) {
+            let generation_gap = generations_new - generations;
+            for gen in (0..(generations + 1)).rev() {
+                let shift = exp2(gen + generation_gap) - exp2(gen);
+                let start = exp2(gen) - 1;
+                let end = start + if gen == generations { leaf_ct } else { exp2(gen) };
+                for index in start..end {
                     let new_index = index + shift;
                     self.nodes.swap(index, new_index);
                 }
             }
+            // Percolate up the old root to ensure that all nodes on the path from the old
+            // root to the new root are initialized. This is not needed in the case where the
+            // generation gap is only 1, as only the root is uninitialized in this case and it will
+            // be initialized after inserting the new node below.
+            if generation_gap > 1 && generations >= 0 {
+                self.perc_up(BinTreeIndex(exp2(generation_gap) - 1))?;
+            }
         }
 
-        self.nodes.push(T::new(Self::leaf_parts(data))?);
-        self.perc_up(self.end())?;
-        Ok(())
+        let index = self.offset_to_index(offset)?;
+        if index.0 >= self.nodes.len() {
+            self.nodes.resize_with(index.0 + 1, T::default);
+        }
+        self.nodes[index.0] = T::new(Self::leaf_parts(data))?;
+        self.perc_up(index)
     }
 
     /// Percolates up the hash change to the given node to the root.
@@ -1480,9 +1528,8 @@ impl<T: MerkleNode> MerkleTree<T> {
             .combine(Self::interior_prefix(), left, right)
             .map_err(|_| Error::IndexOutOfBounds {
                 index: index.0,
-                limit: Self::capacity(self.generations() - 1)
-            })?;
-        Ok(())
+                limit: Self::len(self.generations() - 1)
+            })
     }
 
     /// Verifies that the given data stored from the given offset into the protected data, has not
@@ -1492,7 +1539,7 @@ impl<T: MerkleNode> MerkleTree<T> {
             return Err(Error::RootHashNotVerified)
         }
         self.assert_sector_sz(data.len())?;
-        let start = self.offset_to_index(offset);
+        let start = self.offset_to_index(offset)?;
         self.hash_at(start)?.assert_contains_hash_of(Self::leaf_parts(data))?;
         for index in start.ancestors() {
             let parent = self.hash_at(index)?;
@@ -1504,83 +1551,92 @@ impl<T: MerkleNode> MerkleTree<T> {
     }
 }
 
-struct MerkleStream<T> {
+impl<T> Sectored for MerkleTree<T> {
+    fn sector_sz(&self) -> usize {
+        self.sector_sz
+    }
+}
+
+struct MerkleStream<T, H> {
     inner: T,
-    pos: usize,
-    sector_len: usize,
-    //tree: ???,
+    tree: MerkleTree<H>,
+    offset: usize,
 }
 
-impl MerkleStream<()> {
-    fn new() -> Self {
+impl<H> MerkleStream<(), H> {
+    fn new(tree: MerkleTree<H>) -> Self {
         MerkleStream {
             inner: (),
-            pos: 0,
-            sector_len: SECTOR_SZ_DEFAULT,
+            tree,
+            offset: 0,
         }
     }
 }
 
-impl<T> Compose<T, MerkleStream<T>> for MerkleStream<()> {
-    fn compose(self, inner: T) -> MerkleStream<T> {
+impl<T, H> Sectored for MerkleStream<T, H> {
+    fn sector_sz(&self) -> usize {
+        self.tree.sector_sz()
+    }
+}
+
+impl<T, H> Compose<T, MerkleStream<T, H>> for MerkleStream<(), H> {
+    fn compose(self, inner: T) -> MerkleStream<T, H> {
         MerkleStream {
             inner,
-            pos: 0,
-            sector_len: self.sector_len,
+            tree: self.tree,
+            offset: self.offset,
         }
     }
 }
 
-impl<T: Write> Write for MerkleStream<T> {
-    fn write(&mut self, _buf: &[u8]) -> io::Result<usize> {
-        unimplemented!()
+impl<T: Write, H: MerkleNode> Write for MerkleStream<T, H> {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        self.assert_sector_sz(buf.len())?;
+        self.tree.write(self.offset, buf)?;
+        let written = self.inner.write(buf)?;
+        self.offset += self.sector_sz();
+        Ok(written)
     }
 
     fn flush(&mut self) -> io::Result<()> {
-        unimplemented!()
+        self.inner.flush()
     }
 }
 
-impl<T: Read> Read for MerkleStream<T> {
-    fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
-        unimplemented!()
+impl<T: Read, H: MerkleNode> Read for MerkleStream<T, H> {
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        self.assert_sector_sz(buf.len())?;
+        self.inner.read_exact(buf)?;
+        self.tree.verify(self.offset, buf)?;
+        self.offset += self.sector_sz();
+        Ok(self.sector_sz())
     }
 }
 
-impl<T: Seek> Seek for MerkleStream<T> {
-    fn seek(&mut self, _pos: io::SeekFrom) -> io::Result<u64> {
-        unimplemented!()
+impl<T: Seek, H> Seek for MerkleStream<T, H> {
+    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
+        let from_start = self.inner.seek(pos)?;
+        self.offset = from_start as usize;
+        Ok(from_start)
     }
 }
 
 // A stream which encrypts all data written to it and decrypts all data read from it.
 struct SecretStream<T> {
     inner: T,
+    // The sector size of the inner stream. Reads and writes are only executed using buffers of
+    // this size.
+    inner_sect_sz: usize,
+    // The sector size of this stream. Reads and writes are only accepted for buffers of this size.
+    sect_sz: usize,
     key: SymKey,
     /// Buffer for ciphertext.
     ct_buf: Vec<u8>,
     /// Buffer for plaintext.
     pt_buf: Vec<u8>,
-    // The sector size of this stream. Reads and writes are only accepted for buffers of this size.
-    sect_sz: usize,
-    // The sector size of the inner stream. Reads and writes are only executed using buffers of
-    // this size.
-    inner_sect_sz: usize,
 }
 
 impl<T> SecretStream<T> {
-    fn new(key: SymKey, inner: T, inner_sect_sz: usize) -> Result<SecretStream<T>> {
-        let expansion_sz = key.expansion_sz();
-        let sect_sz = inner_sect_sz - expansion_sz;
-        let block_sz = key.block_size();
-        if 0 != sect_sz % block_sz {
-            return Err(Error::IndivisibleSize { divisor: block_sz, actual: sect_sz })
-        }
-        let ct_buf = vec![0u8; inner_sect_sz];
-        let pt_buf = vec![0u8; inner_sect_sz + block_sz];
-        Ok(SecretStream { inner, key, ct_buf, pt_buf, sect_sz, inner_sect_sz, })
-    }
-
     /// Given an offset into this stream, produces the corresponding offset into the inner stream.
     fn inner_offset(&self, outer_offset: u64) -> u64{
         let sect_sz = self.sect_sz as u64;
@@ -1597,13 +1653,34 @@ impl<T> SecretStream<T> {
     }
 }
 
+impl SecretStream<()> {
+    fn new(key: SymKey) -> SecretStream<()> {
+        SecretStream {
+            inner: (),
+            inner_sect_sz: 0,
+            sect_sz: 0,
+            key,
+            ct_buf: Vec::new(),
+            pt_buf: Vec::new()
+        }
+    }
+}
+
 impl<T, U: Sectored> Compose<U, SecretStream<U>> for SecretStream<T> {
-    fn compose(self, inner: U) -> SecretStream<U> {
+    fn compose(mut self, inner: U) -> SecretStream<U> {
         let inner_sect_sz = inner.sector_sz();
+        let expansion_sz = self.key.expansion_sz();
+        let sect_sz = inner_sect_sz - expansion_sz;
+        let block_sz = self.key.block_size();
+        if 0 != sect_sz % block_sz {
+            panic!("{}", Error::IndivisibleSize { divisor: block_sz, actual: sect_sz })
+        }
+        self.pt_buf.resize(inner_sect_sz, 0);
+        self.pt_buf.resize(inner_sect_sz + block_sz, 0);
         SecretStream {
-            sect_sz: inner_sect_sz - self.key.expansion_sz(),
-            inner_sect_sz,
             inner,
+            inner_sect_sz,
+            sect_sz: inner_sect_sz - expansion_sz,
             key: self.key,
             ct_buf: self.ct_buf,
             pt_buf: self.pt_buf,
@@ -1880,8 +1957,11 @@ pub fn verify_header(_header: &Header, _sig: &Signature) -> Result<()> {
 
 #[cfg(test)]
 mod tests {
+    use crate::{
+        SECTOR_SZ_DEFAULT,
+        test_helpers::*,
+    };
     use super::*;
-    use super::super::test_helpers::*;
     use std::{
         time::Duration, io::{Cursor, SeekFrom},
     };
@@ -2059,49 +2139,112 @@ mod tests {
         assert!(result.is_err())
     }
 
-    fn secret_stream(key_kind: SymKeyKind, num_sectors: usize) -> SecretStream<Cursor<Vec<u8>>> {
-        let key = SymKey::generate(key_kind).expect("key generation failed");
-        let inner = Cursor::new(vec![0u8; num_sectors * SECTOR_SZ_DEFAULT]);
-        SecretStream::new(key, inner, SECTOR_SZ_DEFAULT)
-            .expect("secret_stream creation failed")
-    }
-
-    fn secret_stream_encrypt_decrypt_are_inverse_test_case(key_kind: SymKeyKind) {
-        let mut stream = secret_stream(key_kind, 3);
+    fn secret_stream_sequential_test_case(
+        key: SymKey, inner_sect_sz: usize, sect_ct: usize
+    ) {
+        let mut stream = SecretStream::new(key)
+            .compose(SectoredCursor::new(vec![0u8; inner_sect_sz * sect_ct], inner_sect_sz));
         let sector_sz = stream.sector_sz();
-        let expected = vec![1u8; 3 * sector_sz];
-
-        for sector in expected.chunks(sector_sz) {
-            stream.write(sector).expect("write failed");
+        for k in 0..sect_ct {
+            let sector = vec![k as u8; sector_sz];
+            stream.write(&sector).expect("write failed");
         }
         stream.seek(SeekFrom::Start(0)).expect("seek failed");
-        let mut actual = vec![0u8; 3 * sector_sz];
-        for sector in actual.chunks_mut(sector_sz) {
-            stream.read(sector).expect("read failed");
+        for k in 0..sect_ct {
+            let expected = vec![k as u8; sector_sz];
+            let mut actual = vec![0u8; sector_sz];
+            stream.read(&mut actual).expect("read failed");
+            assert_eq!(expected, actual);
         }
+    }
 
-        assert_eq!(expected, actual);
+    fn secret_stream_sequential_test_suite(kind: SymKeyKind) {
+        let key = SymKey::generate(kind).expect("key generation failed");
+        secret_stream_sequential_test_case(key.clone(), SECTOR_SZ_DEFAULT, 16);
     }
 
     #[test]
     fn secret_stream_encrypt_decrypt_are_inverse_aes256cbc() {
-        secret_stream_encrypt_decrypt_are_inverse_test_case(SymKeyKind::Aes256Cbc)
+        secret_stream_sequential_test_suite(SymKeyKind::Aes256Cbc)
     }
 
     #[test]
     fn secret_stream_encrypt_decrypt_are_inverse_aes256ctr() {
-        secret_stream_encrypt_decrypt_are_inverse_test_case(SymKeyKind::Aes256Ctr)
+        secret_stream_sequential_test_suite(SymKeyKind::Aes256Ctr)
+    }
+
+    fn secret_stream_random_access_test_case(
+        rando: Randomizer, key: SymKey, inner_sect_sz: usize, sect_ct: usize
+    ) {
+        let mut stream = SecretStream::new(key)
+            .compose(SectoredCursor::new(vec![0u8; inner_sect_sz * sect_ct], inner_sect_sz));
+        let sect_sz = stream.sector_sz();
+        let indices: Vec<usize> = rando.take(sect_ct).map(|e| e % sect_ct).collect();
+        for index in indices.iter().map(|e| *e) {
+            let offset = index * sect_sz;
+            stream.seek(SeekFrom::Start(offset as u64)).expect("seek to write failed");
+            let sector = vec![index as u8; sect_sz];
+            stream.write(&sector).expect("write failed");
+        }
+        for index in indices.iter().map(|e| *e) {
+            let offset = index * sect_sz;
+            stream.seek(SeekFrom::Start(offset as u64)).expect("seek to read failed");
+            let expected = vec![index as u8; sect_sz];
+            let mut actual = vec![0u8; sect_sz];
+            stream.read(&mut actual).expect("read failed");
+            assert_eq!(expected, actual);
+        }
+    }
+
+    fn secret_stream_random_access_test_suite(kind: SymKeyKind) {
+        const SEED: [u8; Randomizer::HASH.len()] = [3u8; Randomizer::HASH.len()];
+        let key = SymKey::generate(kind).expect("key generation failed");
+        secret_stream_random_access_test_case(
+            Randomizer::new(SEED), key.clone(), SECTOR_SZ_DEFAULT, 20
+        );
+        secret_stream_random_access_test_case(
+            Randomizer::new(SEED), key.clone(), SECTOR_SZ_DEFAULT, 800
+        );
+        secret_stream_random_access_test_case(Randomizer::new(SEED), key.clone(), 512, 200);
+        secret_stream_random_access_test_case(Randomizer::new(SEED), key.clone(), 512, 20);
+        secret_stream_random_access_test_case(Randomizer::new(SEED), key.clone(), 512, 200);
+    }
+
+    #[test]
+    fn secret_stream_random_access() {
+        secret_stream_random_access_test_suite(SymKeyKind::Aes256Cbc);
+        secret_stream_random_access_test_suite(SymKeyKind::Aes256Ctr);
+    }
+
+    fn make_secret_stream(
+        key_kind: SymKeyKind, num_sectors: usize
+    ) -> SecretStream<SectoredCursor<Vec<u8>>> {
+        let key = SymKey::generate(key_kind).expect("key generation failed");
+        let inner = SectoredCursor::new(
+            vec![0u8; num_sectors * SECTOR_SZ_DEFAULT],
+            SECTOR_SZ_DEFAULT
+        );
+        SecretStream::new(key).compose(inner)
     }
 
     #[test]
     fn secret_stream_seek_from_start() {
-        let mut stream = secret_stream(SymKeyKind::Aes256Cbc, 3);
+        let mut stream = make_secret_stream(SymKeyKind::Aes256Cbc, 3);
         let sector_sz = stream.sector_sz();
         let expected = vec![2u8; sector_sz];
         // Write one sector of ones, one sector of twos and one sector of threes.
         for k in 1..4 {
             let sector: Vec<u8> = std::iter::repeat(k as u8).take(sector_sz).collect();
-            stream.write(&sector).expect("writing to stream failed");
+            stream.write(&sector).expect("writing to stream failed");fn secret_stream(
+                key_kind: SymKeyKind, num_sectors: usize
+            ) -> SecretStream<SectoredCursor<Vec<u8>>> {
+                let key = SymKey::generate(key_kind).expect("key generation failed");
+                let inner = SectoredCursor::new(
+                    vec![0u8; num_sectors * SECTOR_SZ_DEFAULT],
+                    SECTOR_SZ_DEFAULT
+                );
+                SecretStream::new(key).compose(inner)
+            }
         }
 
         stream.seek(SeekFrom::Start(sector_sz as u64)).expect("seek failed");
@@ -2114,7 +2257,7 @@ mod tests {
 
     #[test]
     fn secret_stream_seek_from_current() {
-        let mut stream = secret_stream(SymKeyKind::Aes256Cbc, 3);
+        let mut stream = make_secret_stream(SymKeyKind::Aes256Cbc, 3);
         let sector_sz = stream.sector_sz();
         let expected = vec![3u8; sector_sz];
         // Write one sector of ones, one sector of twos and one sector of threes.
@@ -2133,7 +2276,7 @@ mod tests {
 
     #[test]
     fn log2_test() {
-        assert_eq!(0, log2(0));
+        assert_eq!(-1, log2(0));
         assert_eq!(0, log2(1));
         assert_eq!(1, log2(2));
         assert_eq!(2, log2(4));
@@ -2141,16 +2284,17 @@ mod tests {
         assert_eq!(3, log2(8));
         assert_eq!(9, log2(1023));
         assert_eq!(10, log2(1025));
-        assert_eq!(30, log2(1073741824));
+        assert_eq!(63, log2(usize::MAX));
     }
 
     fn make_tree_with<const SZ: usize>(num_sects: usize) -> (MerkleTree<Sha2_256>, Vec<[u8; SZ]>) {
         let mut tree = MerkleTree::<Sha2_256>::empty(SZ);
         let mut sectors = Vec::with_capacity(num_sects);
         for k in 1..(num_sects + 1) {
+            let offset = SZ * (k - 1);
             let sector = [k as u8; SZ];
             sectors.push(sector);
-            tree.push(&sector).expect("append sector failed");
+            tree.write(offset, &sector).expect("append sector failed");
         }
         (tree, sectors)
     }
@@ -2190,9 +2334,9 @@ mod tests {
         let one = [1u8; SZ];
         let mut two = [2u8; SZ];
         let three = [3u8; SZ];
-        tree.push(&one).expect("append one failed");
-        tree.push(&two).expect("append two failed");
-        tree.push(&three).expect("append three failed");
+        tree.write(0, &one).expect("append one failed");
+        tree.write(SZ, &two).expect("append two failed");
+        tree.write(2 * SZ, &three).expect("append three failed");
 
         two[0] = 7u8;
 
@@ -2208,15 +2352,74 @@ mod tests {
         let one = [1u8; SZ];
         let two = [2u8; SZ];
         let three = [3u8; SZ];
-        tree.push(&one).expect("append one failed");
-        tree.push(&two).expect("append two failed");
-        tree.push(&three).expect("append three failed");
+        tree.write(0, &one).expect("append one failed");
+        tree.write(SZ, &two).expect("append two failed");
+        tree.write(2 * SZ, &three).expect("append three failed");
         let vec = to_vec(&tree).expect("to_vec failed");
         let tree: MerkleTree::<Sha2_256> = from_vec(&vec).expect("from_vec failed");
 
         tree.verify(SZ, &two).expect_err("verify succeeded, though it should have failed");
     }
 
+    fn merkle_stream_sequential_test_case(sect_sz: usize, sect_count: usize) {
+        let mut stream = MerkleStream
+            ::new(MerkleTree::<Sha2_256>::empty(sect_sz))
+            .compose(Cursor::new(vec![0u8; sect_count * sect_sz]));
+        for k in 1..(sect_count + 1) {
+            let sector = vec![k as u8; sect_sz];
+            stream.write(&sector).expect("write failed");
+        }
+        stream.seek(SeekFrom::Start(0)).expect("seek failed");
+        for k in 1..(sect_count + 1) {
+            let expected = vec![k as u8; sect_sz];
+            let mut actual = vec![0u8; sect_sz];
+            stream.read(&mut actual).expect("read failed");
+            assert_eq!(expected, actual);
+        }
+    }
+
+    #[test]
+    fn merkle_stream_sequential() {
+        merkle_stream_sequential_test_case(SECTOR_SZ_DEFAULT, 20);
+        merkle_stream_sequential_test_case(SECTOR_SZ_DEFAULT, 200);
+        merkle_stream_sequential_test_case(SECTOR_SZ_DEFAULT, 800);
+        merkle_stream_sequential_test_case(512, 25);
+        merkle_stream_sequential_test_case(8192, 20);
+    }
+
+    fn merkle_stream_random_test_case(rando: Randomizer, sect_sz: usize, sect_ct: usize) {
+        let mut stream = MerkleStream
+            ::new(MerkleTree::<Sha2_256>::empty(sect_sz))
+            .compose(Cursor::new(vec![0u8; sect_sz * sect_ct]));
+        let indices: Vec<usize> = rando.take(sect_ct).map(|e| e % sect_ct).collect();
+        for index in indices.iter().map(|e| *e) {
+            let offset = sect_sz * index;
+            stream.seek(SeekFrom::Start(offset as u64)).expect("seek to write failed");
+            let sector = vec![index as u8; sect_sz];
+            stream.write(&sector).expect("write failed");
+        }
+        for index in indices.iter().map(|e| *e) {
+            let offset = sect_sz * index;
+            stream.seek(SeekFrom::Start(offset as u64)).expect("seek to read failed");
+            let expected = vec![index as u8; sect_sz];
+            let mut actual = vec![0u8; sect_sz];
+            stream.read(&mut actual).expect("read failed");
+            assert_eq!(expected, actual);
+        }
+    }
+
+    #[test]
+    fn merkle_stream_random() {
+        const SEED: [u8; Randomizer::HASH.len()] = [3u8; Randomizer::HASH.len()];
+        merkle_stream_random_test_case(Randomizer::new(SEED), SECTOR_SZ_DEFAULT, 2);
+        merkle_stream_random_test_case(Randomizer::new(SEED), SECTOR_SZ_DEFAULT, 4);
+        merkle_stream_random_test_case(Randomizer::new(SEED), SECTOR_SZ_DEFAULT, 8);
+        merkle_stream_random_test_case(Randomizer::new(SEED), SECTOR_SZ_DEFAULT, 20);
+        merkle_stream_random_test_case(Randomizer::new(SEED), SECTOR_SZ_DEFAULT, 200);
+        merkle_stream_random_test_case(Randomizer::new(SEED), SECTOR_SZ_DEFAULT, 800);
+        merkle_stream_random_test_case(Randomizer::new(SEED), 8192, 63);
+    }
+
     /// Tests that validate the dependencies of this module.
     mod dependency_tests {
         use super::*;

+ 70 - 1
crates/btnode/src/test_helpers.rs

@@ -3,7 +3,7 @@
 use super::*;
 use crypto::*;
 use serde_block_tree::{Error, Result};
-use std::{ fs::File, io::Write, fmt::Write as FmtWrite };
+use std::{ fs::File, io::{Write, Cursor}, fmt::Write as FmtWrite };
 
 pub const PRINCIPAL: [u8; 32] = [
     0x75, 0x28, 0xA9, 0xE0, 0x9D, 0x24, 0xBA, 0xB3, 0x79, 0x56, 0x15, 0x68, 0xFD, 0xA4, 0xE2, 0xA4,
@@ -527,4 +527,73 @@ fn write_slice<W: Write>(output: &mut W, name: &str, slice: &[u8]) -> Result<()>
     writeln!(output, "];").map_err(Error::Io)?;
     writeln!(output).map_err(Error::Io)?;
     Ok(())
+}
+
+/// A naive randomizer implementation that is intended only for testing.
+pub struct Randomizer {
+    state: [u8; Self::HASH.len()],
+    buf: [u8; Self::HASH.len()],
+}
+
+impl Randomizer {
+    pub const HASH: HashKind = HashKind::Sha2_256;
+
+    pub fn new(seed: [u8; Self::HASH.len()]) -> Randomizer {
+        Randomizer { state: seed, buf: [0u8; Self::HASH.len()] }
+    }
+}
+
+impl Iterator for Randomizer {
+    type Item = usize;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        const BYTES: usize = usize::BITS as usize / 8;
+        Self::HASH.digest(&mut self.buf, std::iter::once(self.state.as_slice()))
+            .expect("digest failed");
+        self.state.copy_from_slice(&self.buf);
+        let int_bytes = self.buf
+            .as_slice()[..BYTES]
+            .try_into()
+            .expect("failed to convert array");
+        Some(usize::from_ne_bytes(int_bytes))
+    }
+}
+
+pub struct SectoredCursor<T> {
+    cursor: Cursor<T>,
+    sect_sz: usize,
+}
+
+impl<T> SectoredCursor<T> {
+    pub fn new(inner: T, sect_sz: usize) -> SectoredCursor<T> {
+        SectoredCursor { cursor: Cursor::new(inner), sect_sz }
+    }
+}
+
+impl<T> Sectored for SectoredCursor<T> {
+    fn sector_sz(&self) -> usize {
+        self.sect_sz
+    }
+}
+
+impl Write for SectoredCursor<Vec<u8>> {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        self.cursor.write(buf)
+    }
+
+    fn flush(&mut self) -> io::Result<()> {
+        self.cursor.flush()
+    }
+}
+
+impl<T: AsRef<[u8]>> Read for SectoredCursor<T> {
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+        self.cursor.read(buf)
+    }
+}
+
+impl<T: AsRef<[u8]>> Seek for SectoredCursor<T> {
+    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
+        self.cursor.seek(pos)
+    }
 }