Browse Source

Performed some cleanup in the crypto module.

Matthew Carr 2 years ago
parent
commit
9fbf290044
1 changed files with 31 additions and 38 deletions
  1. 31 38
      crates/node/src/crypto.rs

+ 31 - 38
crates/node/src/crypto.rs

@@ -33,7 +33,7 @@ pub(crate) enum Error {
     MissingPrivateKey,
     KeyVariantUnsupported,
     BlockNotEncrypted,
-    InvalidFormat,
+    InvalidHashFormat,
     Message(String),
     Serde(serde_block_tree::Error),
 }
@@ -46,7 +46,7 @@ impl Display for Error {
             Error::MissingPrivateKey => write!(f, "private key was missing"),
             Error::KeyVariantUnsupported => write!(f, "unsupported key variant"),
             Error::BlockNotEncrypted => write!(f, "block was not encrypted"),
-            Error::InvalidFormat => write!(f, "invalid format"),
+            Error::InvalidHashFormat => write!(f, "invalid format"),
             Error::Message(message) => f.write_str(message.as_str()),
             Error::Serde(serde_err) => serde_err.fmt(f),
         }
@@ -125,14 +125,14 @@ impl TryFrom<&str> for Hash {
     fn try_from(string: &str) -> Result<Hash> {
         let mut split: Vec<&str> = string.split(Self::HASH_SEP).collect();
         if split.len() != 2 {
-            return Err(Error::InvalidFormat)
+            return Err(Error::InvalidHashFormat)
         };
-        let second = split.pop().ok_or(Error::InvalidFormat)?;
-        let first = split.pop().ok_or(Error::InvalidFormat)?;
+        let second = split.pop().ok_or(Error::InvalidHashFormat)?;
+        let first = split.pop().ok_or(Error::InvalidHashFormat)?;
         let mut hash = Hash::from(HashKind::from_str(first)
-            .map_err(|_| Error::InvalidFormat)?);
+            .map_err(|_| Error::InvalidHashFormat)?);
         base64_url::decode_to_slice(second, hash.as_mut())
-            .map_err(|_| Error::InvalidFormat)?;
+            .map_err(|_| Error::InvalidHashFormat)?;
         Ok(hash)
     }
 }
@@ -148,20 +148,17 @@ impl Display for Hash {
 const RSA_SIG_LEN: usize = 512;
 
 /// A cryptographic signature.
-#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
+#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, EnumDiscriminants)]
+#[strum_discriminants(name(SignatureKind))]
 pub(crate) enum Signature {
     #[serde(with = "BigArray")]
     Rsa([u8; RSA_SIG_LEN]),
 }
 
-pub(crate) enum SignatureId {
-    Rsa
-}
-
 impl Signature {
-    fn new(id: SignatureId) -> Signature {
+    fn new(id: SignatureKind) -> Signature {
         match id {
-            SignatureId::Rsa => Signature::Rsa([0; RSA_SIG_LEN])
+            SignatureKind::Rsa => Signature::Rsa([0; RSA_SIG_LEN])
         }
     }
 
@@ -178,6 +175,18 @@ impl Signature {
     }
 }
 
+impl AsRef<[u8]> for Signature {
+    fn as_ref(&self) -> &[u8] {
+        self.as_slice()
+    }
+}
+
+impl AsMut<[u8]> for Signature {
+    fn as_mut(&mut self) -> &mut [u8] {
+        self.as_mut_slice()
+    }
+}
+
 impl Default for Signature {
     fn default() -> Self {
         Signature::Rsa([0; RSA_SIG_LEN])
@@ -228,15 +237,6 @@ pub(crate) enum SymKey {
     },
 }
 
-fn generate_key_and_iv<const KEY_LEN: usize, const IV_LEN: usize>(
-) -> Result<([u8; KEY_LEN], [u8; IV_LEN])> {
-    let mut key = [0; KEY_LEN];
-    let mut iv = [0; IV_LEN];
-    rand_bytes(&mut key).map_err(Error::from)?;
-    rand_bytes(&mut iv).map_err(Error::from)?;
-    Ok((key, iv))
-}
-
 /// Returns an array of the given length filled with cryptographically random data.
 fn rand_array<const LEN: usize>() -> Result<[u8; LEN]> {
     let mut array = [0; LEN];
@@ -283,19 +283,16 @@ impl<T> AsymKey<T> {
 
 impl<T> Owned for AsymKey<T> {
     fn owner_of_kind(&self, kind: HashKind) -> Principal {
-        let slice = match self {
-            AsymKey::Rsa { der, .. } => der.as_slice(),
-        };
         match kind {
             HashKind::Sha2_256 => {
                 let mut buf = [0; 32];
-                let bytes = hash(MessageDigest::sha256(), slice).unwrap();
+                let bytes = hash(MessageDigest::sha256(), self.as_slice()).unwrap();
                 buf.copy_from_slice(&*bytes);
                 Principal(Hash::Sha2_256(buf))
             },
             HashKind::Sha2_512 => {
                 let mut buf = [0; 64];
-                let bytes = hash(MessageDigest::sha512(), slice).unwrap();
+                let bytes = hash(MessageDigest::sha512(), self.as_slice()).unwrap();
                 buf.copy_from_slice(&*bytes);
                 Principal(Hash::Sha2_512(buf))
             }
@@ -395,9 +392,8 @@ impl<'a> TryFrom<&'a AsymKey<Public>> for EncryptionAlgo<'a> {
     fn try_from(key: &'a AsymKey<Public>) -> Result<EncryptionAlgo<'a>> {
         match key {
             AsymKey::Rsa { der, padding, .. } => {
-                let pkey = PKey::public_key_from_der(der.as_slice())
-                    .map_err(|err| Error::Message(err.to_string()));
-                Ok(EncryptionAlgo::Asymmetric { key: pkey?, rsa_padding: Some((*padding).into()) })
+                let key = PKey::public_key_from_der(der.as_slice()).map_err(Error::from)?;
+                Ok(EncryptionAlgo::Asymmetric { key, rsa_padding: Some((*padding).into()) })
             },
         }
     }
@@ -488,7 +484,7 @@ impl TryFrom<&AsymKey<Private>> for SignAlgo {
                 Ok(SignAlgo {
                     key: PKey::from_rsa(rsa).map_err(Error::from)?,
                     digest: MessageDigest::sha256(),
-                    signature: Signature::new(SignatureId::Rsa)
+                    signature: Signature::new(SignatureKind::Rsa)
                 })
             },
         }
@@ -559,7 +555,7 @@ pub(crate) fn decrypt_block(
     Ok(block)
 }
 
-fn encrypt<'a, K: TryInto<EncryptionAlgo<'a>, Error = Error>, T: Serialize>(
+fn encrypt<'a, T: Serialize, K: TryInto<EncryptionAlgo<'a>, Error = Error>>(
     value: &T, key: K
 ) -> Result<Cryptotext<T>> {
     let data = to_vec(value).map_err(Error::from)?;
@@ -567,7 +563,7 @@ fn encrypt<'a, K: TryInto<EncryptionAlgo<'a>, Error = Error>, T: Serialize>(
     Ok(Cryptotext::Cipher(vec?))
 }
 
-fn decrypt<'a, K: TryInto<DecryptionAlgo<'a>, Error = Error>, T: Serialize + DeserializeOwned>(
+fn decrypt<'a, T: Serialize + DeserializeOwned, K: TryInto<DecryptionAlgo<'a>, Error = Error>>(
     cryptotext: Cryptotext<T>, key: K
 ) -> Result<T> {
     let data = match cryptotext {
@@ -638,10 +634,7 @@ pub(crate) fn verify_block(block: &Block) -> Result<bool> {
     let header = to_vec(&sig_header)?;
     let verify_algo = VerifyAlgo::try_from(&block.writecap.signing_key)?;
     let parts = [header.as_slice(), body].into_iter();
-    if !verify_algo.verify(parts, block.signature.as_slice())? {
-        return Ok(false);
-    }
-    Ok(true)
+    Ok(verify_algo.verify(parts, block.signature.as_slice())?)
 }
 
 #[derive(Serialize)]