Browse Source

Implemented a Merkle tree to be used with `MerkleStream`.

Matthew Carr 2 years ago
parent
commit
0196baa710
1 changed files with 498 additions and 7 deletions
  1. 498 7
      crates/btnode/src/crypto/mod.rs

+ 498 - 7
crates/btnode/src/crypto/mod.rs

@@ -34,7 +34,7 @@ use openssl::{
     symm::{Cipher, encrypt as openssl_encrypt, decrypt as openssl_decrypt, Crypter, Mode},
     rand::rand_bytes,
     rsa::{Rsa as OsslRsa, Padding as OpensslPadding},
-    hash::{hash, MessageDigest},
+    hash::{hash, MessageDigest, Hasher},
     sign::{Signer as OsslSigner, Verifier as OsslVerifier}
 };
 use serde_block_tree::{self, to_vec, from_vec, write_to};
@@ -69,7 +69,10 @@ pub enum Error {
     InvalidHashFormat,
     InvalidSignature,
     IncorrectSize { expected: usize, actual: usize },
+    IndexOutOfBounds { index: usize, limit: usize },
     IndivisibleSize { divisor: usize, actual: usize },
+    HashCmpFailure,
+    RootHashNotVerified,
     WritecapAuthzErr(WritecapAuthzErr),
     Serde(serde_block_tree::Error),
     Io(std::io::Error),
@@ -94,11 +97,19 @@ impl Display for Error {
             Error::InvalidSignature => write!(f, "invalid signature"),
             Error::IncorrectSize { expected, actual }
                 => write!(f, "expected size {} but got {}", expected, actual),
-            Error::IndivisibleSize { divisor, actual }
+            Error::IndexOutOfBounds { index, limit }
                 => write!(
                     f,
-                    "expected a size which is divisible by {} but got {}", divisor, actual
+                    "index {} is out of bounds, it must be strictly less than {}",
+                    index,
+                    limit
                 ),
+            Error::IndivisibleSize { divisor, actual }
+                => write!(
+                    f, "expected a size which is divisible by {} but got {}", divisor, actual
+                ),
+            Error::HashCmpFailure => write!(f, "hash data are not equal"),
+            Error::RootHashNotVerified => write!(f, "root hash is not verified"),
             Error::WritecapAuthzErr(err) => err.fmt(f),
             Error::Serde(err) => err.fmt(f),
             Error::Io(err) => err.fmt(f),
@@ -166,9 +177,9 @@ fn rand_vec(len: usize) -> Result<Vec<u8>> {
 #[strum_discriminants(derive(EnumString, Display, Serialize, Deserialize))]
 #[strum_discriminants(name(HashKind))]
 pub(crate) enum Hash {
-    Sha2_256([u8; 32]),
+    Sha2_256([u8; HashKind::Sha2_256.len()]),
     #[serde(with = "BigArray")]
-    Sha2_512([u8; 64]),
+    Sha2_512([u8; HashKind::Sha2_512.len()]),
 }
 
 impl Default for HashKind {
@@ -208,8 +219,8 @@ impl Hash {
 impl From<HashKind> for Hash {
     fn from(discriminant: HashKind) -> Hash {
         match discriminant {
-            HashKind::Sha2_512 => Hash::Sha2_512([0u8; 64]),
-            HashKind::Sha2_256 => Hash::Sha2_256([0u8; 32])
+            HashKind::Sha2_256 => Hash::Sha2_256([0u8; HashKind::Sha2_256.len()]),
+            HashKind::Sha2_512 => Hash::Sha2_512([0u8; HashKind::Sha2_512.len()]),
         }
     }
 }
@@ -1099,6 +1110,400 @@ pub(crate) trait CredStore {
     ) -> Result<Self::CredHandle>;
 }
 
+/// Returns the base 2 logarithm of the given number, rounded down. Returns 0 when given 0.
+fn log2(mut n: usize) -> usize {
+    // Is there a better implementation of this in std? I wasn't able to find an integer log2
+    // function in std, so I wrote this naive implementation.
+    if 0 == n {
+        return 0;
+    }
+    let num_bits = 8 * std::mem::size_of::<usize>(); 
+    for k in 0..num_bits {
+        n >>= 1;
+        if 0 == n {
+            return k;
+        }
+    }
+    num_bits
+}
+
+/// Returns 2^x.
+fn exp2(x: usize) -> usize {
+    1 << x
+}
+
+/// Trait for types which can be used as nodes in a `MerkleTree`.
+trait MerkleNode: Default + Serialize + for<'de> Deserialize<'de> {
+    /// The kind of hash algorithm that this `MerkleNode` uses.
+    const KIND: HashKind;
+
+    /// Computes the hash of the data produced by the given iterator and writes it to the
+    /// given slice.
+    fn digest<'a, I: Iterator<Item = &'a [u8]>>(dest: &mut [u8], parts: I) -> Result<()>;
+
+    /// Creates a new `MerkleNode` instance by hashing the data produced by the given iterator
+    /// and storing the resulting hash in the new instance.
+    fn new<'a, I: Iterator<Item = &'a [u8]>>(parts: I) -> Result<Self>;
+
+    /// Combines the hash data from the given children and prefix and stores it in self. It is
+    /// an error for no children to be provided (though one or the other may be `None`).
+    fn combine<'a, I: Iterator<Item = &'a [u8]>>(
+        &mut self, prefix: I, left: Option<&'a Self>, right: Option<&'a Self>
+    ) -> Result<()>;
+
+    /// Returns `Ok(())` if self contains the given hash data, and `Err(Error::HashCmpFailure)`
+    /// otherwise.
+    fn assert_contains(&self, hash_data: &[u8]) -> Result<()>;
+
+    /// Returns `Ok(())` if self contains the hash of the given data. Otherwise,
+    /// `Err(Error::HashCmpFailure)` is returned.
+    fn assert_contains_hash_of<'a, I: Iterator<Item = &'a [u8]>>(&self, parts: I) -> Result<()>;
+
+    /// Returns `Ok(())` if the result of combining left and right is contained in self.
+    fn assert_parent_of<'a, I: Iterator<Item = &'a [u8]>>(
+        &self, prefix: I, left: Option<&'a Self>, right: Option<&'a Self>
+    ) -> Result<()>;
+}
+
+// TODO: Once full const generic support lands we can use a HashKind as a const param. Then we won't
+// need to have different structs to support different kinds of hashes. 
+/// A struct for storing SHA2 256 hashes in a `MerkleTree`.
+#[derive(Default, Serialize, Deserialize)]
+struct Sha2_256(Option<[u8; HashKind::Sha2_256.len()]>);
+
+impl Sha2_256 {
+    fn as_slice(&self) -> Option<&[u8]> {
+        self.0.as_ref().map(|e| e.as_slice())
+    }
+
+    fn as_mut_slice(&mut self) -> Option<&mut [u8]> {
+        self.0.as_mut().map(|e| e.as_mut_slice())
+    }
+
+    /// Returns a mutable reference to the array contained in self, if the array already exists.
+    /// Otherwise, creates a new array filled with zeros owned by self and returns a
+    /// reference.
+    fn mut_or_init(&mut self) -> &mut [u8] {
+        if self.0.is_none() {
+            self.0 = Some([0; HashKind::Sha2_256.len()])
+        }
+        self.0.as_mut().unwrap()
+    }
+
+    // I think this is the most complicated function signature I've ever written in any language.
+    /// Combines the given slices, together with the given prefix, and stores the resulting hash
+    /// in `dest`. If neither `left` nor `right` is `Some`, then `when_neither` is called and
+    /// whatever it returns is returned by this method.
+    fn combine_hash_data<'a, I: Iterator<Item = &'a [u8]>, F: FnOnce() -> Result<()>>(
+        dest: &mut [u8], prefix: I, left: Option<&'a [u8]>, right: Option<&'a [u8]>, when_neither: F
+    ) -> Result<()> {
+        match (left, right) {
+            (Some(left), Some(right))
+                => Self::digest(dest, prefix.chain([left, right].into_iter())),
+            (Some(left), None)
+                => Self::digest(dest, prefix.chain([left].into_iter())),
+            (None, Some(right))
+                => Self::digest(dest, prefix.chain([right].into_iter())),
+            (None, None) => when_neither(),
+        }
+    }
+}
+
+impl MerkleNode for Sha2_256 {
+    const KIND: HashKind = HashKind::Sha2_256;
+
+    fn digest<'a, I: Iterator<Item = &'a [u8]>>(dest: &mut [u8], parts: I) -> Result<()> {
+        if dest.len() != Self::KIND.len() {
+            return Err(Error::IncorrectSize { expected: Self::KIND.len(), actual: dest.len() })
+        }
+        let mut hasher = Hasher::new(Self::KIND.into())?;
+        for part in parts {
+            hasher.update(part)?;
+        }
+        let hash = hasher.finish()?;
+        dest.copy_from_slice(&hash);
+        Ok(())
+    }
+
+    fn new<'a, I: Iterator<Item = &'a [u8]>>(parts: I) -> Result<Self> {
+        let mut array = [0u8; Self::KIND.len()];
+        Self::digest(&mut array, parts)?;
+        Ok(Sha2_256(Some(array)))
+    }
+
+    fn combine<'a , I: Iterator<Item = &'a [u8]>>(
+        &mut self, prefix: I, left: Option<&'a Self>, right: Option<&'a Self>
+    ) -> Result<()> {
+        let left = left.and_then(|e| e.as_slice());
+        let right = right.and_then(|e| e.as_slice());
+        Self::combine_hash_data(
+            self.mut_or_init(),
+            prefix,
+            left,
+            right,
+            || Err(Error::custom("at least one argument to combine needs to supply data"))
+        )
+    }
+
+    fn assert_contains(&self, hash_data: &[u8]) -> Result<()> {
+        if let Some(slice) = self.as_slice() {
+            if slice == hash_data {
+                return Ok(())
+            }
+        }
+        Err(Error::HashCmpFailure)
+    }
+
+    fn assert_contains_hash_of<'a, I: Iterator<Item = &'a [u8]>>(&self, parts: I) -> Result<()> {
+        let mut buf = [0u8; Self::KIND.len()];
+        Self::digest(&mut buf, parts)?;
+        self.assert_contains(&buf)
+    }
+
+    fn assert_parent_of<'a, I: Iterator<Item = &'a [u8]>>(
+        &self, prefix: I, left: Option<&'a Self>, right: Option<&'a Self>
+    ) -> Result<()> {
+        let slice = match self.as_slice() {
+            Some(slice) => slice,
+            None => return Err(Error::HashCmpFailure),
+        };
+        let buf = {
+            let mut buf = [0u8; Self::KIND.len()];
+            let left = left.and_then(|e| e.as_slice());
+            let right = right.and_then(|e| e.as_slice());
+            Self::combine_hash_data(
+                &mut buf, prefix, left, right, || Err(Error::custom("logic error encountered"))
+            )?;
+            buf
+        };
+        if slice == buf {
+            Ok(())
+        }
+        else {
+            Err(Error::HashCmpFailure)
+        }
+    }
+}
+
+/// An index into a binary tree. This type provides convenience methods for navigating a tree.
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
+struct BinTreeIndex(usize);
+
+impl BinTreeIndex {
+    /// Returns the index of the left child of this node.
+    fn left(self) -> Self {
+        Self(2 * self.0 + 1)
+    }
+
+    /// Returns the index of the right child of this node.
+    fn right(self) -> Self {
+        Self(2 * (self.0 + 1))
+    }
+
+    /// Returns the index of the parent of this node.
+    fn parent(self) -> Option<Self> {
+        if self.0 > 0 {
+            Some(Self((self.0 - 1) / 2))
+        }
+        else {
+            None
+        }
+    }
+
+    /// Returns an iterator over the indices of all of this node's ancestors.
+    fn ancestors(self) -> impl Iterator<Item = BinTreeIndex> {
+        struct ParentIter(Option<BinTreeIndex>);
+
+        impl Iterator for ParentIter {
+            type Item = BinTreeIndex;
+            
+            fn next(&mut self) -> Option<Self::Item> {
+                let parent = match self.0 {
+                    Some(curr) => curr.parent(),
+                    None => None,
+                };
+                self.0 = parent;
+                parent
+            }
+        }
+
+        ParentIter(Some(self))
+    }
+}
+
+/// An implementation of a Merkle tree, a tree for storing hashes. This implementation is a binary
+/// tree which stores its nodes in a vector to ensure data locality.
+/// 
+/// This type is used to provide integrity protection to a sequence of fixed sized units of data
+/// called sectors. The size of the sectors is determined when the tree is created and cannot
+/// be changed later. The hashes contained in the leaf nodes of this tree are hashes of sectors.
+/// Each sector corresponds to an offset into the protected data, and in order to verify that a
+/// sector has not been modified, you must supply the offset of the sector.
+#[derive(Serialize, Deserialize)]
+struct MerkleTree<T> {
+    nodes: Vec<T>,
+    /// The size of the sectors of data that this tree will protect.
+    sector_sz: usize,
+    #[serde(skip)]
+    root_verified: bool,
+}
+
+impl<T> MerkleTree<T> {
+    /// A byte to prefix data being hashed for leaf nodes.
+    const LEAF_PREFIX: &'static [u8] = &[0x00];
+    /// A byte to prefix data being hashed for interior nodes.
+    const INTERIOR_PREFIX: &'static [u8] = &[0x01];
+    
+    /// Creates a new tree with no nodes in it and the given sector size.
+    fn empty(sector_sz: usize) -> MerkleTree<T> {
+        MerkleTree { nodes: Vec::new(), sector_sz, root_verified: true }
+    }
+
+    /// Returns the number of generations in self.
+    fn generations(&self) -> usize {
+        log2(self.nodes.len())
+    }
+
+    fn assert_sector_sz(&self, len: usize) -> Result<()> {
+        if self.sector_sz == len {
+            Ok(())
+        }
+        else {
+            Err(Error::IncorrectSize { expected: self.sector_sz, actual: len })
+        }
+    }
+
+    /// Returns capacity needed to store a complete binary tree with the given number of
+    /// generations. Note that `generations` is 0-based, so a tree with 1 node has 0 generations,
+    /// and a tree with 3 has 1.
+    fn capacity(generations: usize) -> usize {
+        exp2(generations + 1) - 1
+    }
+
+    /// Returns the index of the last node in the tree.
+    fn end(&self) -> BinTreeIndex {
+        BinTreeIndex(self.nodes.len() - 1)
+    }
+
+    /// Returns a reference to the hash stored in the given node, or `Error::IndexOutOfBounds` if
+    /// the given index doesn't exist.
+    fn hash_at(&self, index: BinTreeIndex) -> Result<&T> {
+        self.nodes
+            .get(index.0)
+            .ok_or(Error::IndexOutOfBounds { index: index.0, limit: self.nodes.len() })
+    }
+
+    /// Returns the index which corresponds to the given offset into the protected data.
+    fn offset_to_index(&self, offset: usize) -> BinTreeIndex {
+        BinTreeIndex(exp2(self.generations()) - 1 + offset / self.sector_sz)
+    }
+
+    /// Returns an iterator of slices which need to be hashed along with the data to create a leaf
+    /// node.
+    fn leaf_parts(data: &[u8]) -> impl Iterator<Item = &[u8]> {
+        [Self::LEAF_PREFIX, data].into_iter()
+    }
+
+    /// Returns an iterator of slices which need to be hashed along with the data to create an
+    /// interior node.
+    fn interior_prefix<'a>() -> impl Iterator<Item = &'a [u8]> {
+        [Self::INTERIOR_PREFIX].into_iter()
+    }
+}
+
+impl<T: MerkleNode> MerkleTree<T> {
+    /// Checks that the root node contains the given hash data. If it does then `Ok(())` is
+    /// returned. If it doesn't, then `Err(Error::HashCmpFailure)` is returned.
+    fn assert_root_contains(&mut self, hash_data: &[u8]) -> Result<()> {
+        match self.hash_at(BinTreeIndex(0)) {
+            Ok(root) => {
+                root.assert_contains(hash_data)?;
+                self.root_verified = true;
+                Ok(())
+            }
+            Err(_) => Err(Error::HashCmpFailure),
+        }
+    }
+
+    /// Hashes the given data, adds a new node to the tree with its hash and updates the hashes
+    /// of all parent nodes.
+    fn push(&mut self, data: &[u8]) -> Result<()> {
+        self.assert_sector_sz(data.len())?;
+
+        if self.nodes.is_empty() {
+            self.nodes.push(T::new(Self::leaf_parts(data))?);
+            return Ok(())
+        }
+
+        let generations = self.generations();
+        if Self::capacity(generations) == self.nodes.len() {
+            // Need to resize the tree.
+            let new_sz = Self::capacity(generations + 1);
+            self.nodes.reserve_exact(new_sz - self.nodes.len());
+            // Extend the vector so that half of the new generation is allocated.
+            self.nodes.resize_with(self.nodes.len() + exp2(generations), T::default);
+            // Shift all previously allocated nodes down the tree.
+            for k in (0..(generations + 1)).rev() {
+                let shift = exp2(k);
+                let start = exp2(k) - 1;
+                for index in start..(2*start + 1) {
+                    let new_index = index + shift;
+                    self.nodes.swap(index, new_index);
+                }
+            }
+        }
+
+        self.nodes.push(T::new(Self::leaf_parts(data))?);
+        self.perc_up(self.end())?;
+        Ok(())
+    }
+
+    /// Percolates the hash change at the given node up to the root.
+    fn perc_up(&mut self, start: BinTreeIndex) -> Result<()> {
+        for index in start.ancestors() {
+            self.combine_children(index)?;
+        }
+        Ok(())
+    }
+
+    /// Combines the hashes of the given node's children and stores it in the given node.
+    fn combine_children(&mut self, index: BinTreeIndex) -> Result<()> {
+        let left = index.left();
+        let right = index.right(); 
+        // Note that index < left && index < right.
+        let split = index.0 + 1;
+        let (front, back) = self.nodes.split_at_mut(split);
+        let dest = &mut front[front.len() - 1];
+        let left = back.get(left.0 - split);
+        let right = back.get(right.0 - split);
+        dest
+            .combine(Self::interior_prefix(), left, right)
+            .map_err(|_| Error::IndexOutOfBounds {
+                index: index.0,
+                limit: Self::capacity(self.generations() - 1)
+            })?;
+        Ok(())
+    }
+
+    /// Verifies that the given data, stored at the given offset into the protected data, has not
+    /// been modified.
+    fn verify(&self, offset: usize, data: &[u8]) -> Result<()> {
+        if !self.root_verified {
+            return Err(Error::RootHashNotVerified)
+        }
+        self.assert_sector_sz(data.len())?;
+        let start = self.offset_to_index(offset);
+        self.hash_at(start)?.assert_contains_hash_of(Self::leaf_parts(data))?;
+        for index in start.ancestors() {
+            let parent = self.hash_at(index)?;
+            let left = self.hash_at(index.left()).ok();
+            let right = self.hash_at(index.right()).ok();
+            parent.assert_parent_of(Self::interior_prefix(), left, right)?;
+        }
+        Ok(())
+    }
+}
+
 struct MerkleStream<T> {
     inner: T,
     pos: usize,
@@ -1726,6 +2131,92 @@ mod tests {
         assert_eq!(expected, actual);
     }
 
+    #[test]
+    fn log2_test() {
+        assert_eq!(0, log2(0));
+        assert_eq!(0, log2(1));
+        assert_eq!(1, log2(2));
+        assert_eq!(2, log2(4));
+        assert_eq!(2, log2(5));
+        assert_eq!(3, log2(8));
+        assert_eq!(9, log2(1023));
+        assert_eq!(10, log2(1025));
+        assert_eq!(30, log2(1073741824));
+    }
+
+    fn make_tree_with<const SZ: usize>(num_sects: usize) -> (MerkleTree<Sha2_256>, Vec<[u8; SZ]>) {
+        let mut tree = MerkleTree::<Sha2_256>::empty(SZ);
+        let mut sectors = Vec::with_capacity(num_sects);
+        for k in 1..(num_sects + 1) {
+            let sector = [k as u8; SZ];
+            sectors.push(sector);
+            tree.push(&sector).expect("append sector failed");
+        }
+        (tree, sectors)
+    }
+
+    fn merkle_tree_build_verify_test_case<const SZ: usize>(num_sects: usize) {
+        let (tree, sectors) = make_tree_with::<SZ>(num_sects);
+        for (k, sector) in sectors.into_iter().enumerate() {
+            tree.verify(k * SZ, &sector).expect("verify failed");
+        }
+    }
+
+    #[test]
+    fn merkle_tree_append_verify() {
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(0));
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(1));
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(2));
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(3));
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(4));
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(0) + 1);
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(1) + 1);
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(2) + 1);
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(3) + 1);
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(4) + 1);
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(0) - 1);
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(1) - 1);
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(2) - 1);
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(3) - 1);
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(exp2(4) - 1);
+        merkle_tree_build_verify_test_case::<SECTOR_SZ_DEFAULT>(1337);
+        merkle_tree_build_verify_test_case::<512>(37);
+    }
+
+    #[test]
+    fn merkle_tree_data_changed_verify_fails() {
+        const SZ: usize = SECTOR_SZ_DEFAULT;
+        let mut tree = MerkleTree::<Sha2_256>::empty(SZ);
+        let one = [1u8; SZ];
+        let mut two = [2u8; SZ];
+        let three = [3u8; SZ];
+        tree.push(&one).expect("append one failed");
+        tree.push(&two).expect("append two failed");
+        tree.push(&three).expect("append three failed");
+
+        two[0] = 7u8;
+
+        tree.verify(0, &one).expect("failed to verify one");
+        tree.verify(SZ, &two).expect_err("verify two was expected to fail");
+        tree.verify(2 * SZ, &three).expect("failed to verify three");
+    }
+
+    #[test]
+    fn merkle_tree_root_not_verified_verify_fails() {
+        const SZ: usize = SECTOR_SZ_DEFAULT;
+        let mut tree = MerkleTree::<Sha2_256>::empty(SZ);
+        let one = [1u8; SZ];
+        let two = [2u8; SZ];
+        let three = [3u8; SZ];
+        tree.push(&one).expect("append one failed");
+        tree.push(&two).expect("append two failed");
+        tree.push(&three).expect("append three failed");
+        let vec = to_vec(&tree).expect("to_vec failed");
+        let tree: MerkleTree::<Sha2_256> = from_vec(&vec).expect("from_vec failed");
+
+        tree.verify(SZ, &two).expect_err("verify succeeded, though it should have failed");
+    }
+
     /// Tests that validate the dependencies of this module.
     mod dependency_tests {
         use super::*;