Browse Source

Used strum to eliminate duplicate enum definition.

Matthew Carr 2 years ago
parent
commit
80eeef4e44
2 changed files with 25 additions and 25 deletions
  1. 23 23
      crates/node/src/crypto.rs
  2. 2 2
      crates/node/src/test_helpers.rs

+ 23 - 23
crates/node/src/crypto.rs

@@ -70,6 +70,7 @@ pub type Result<T> = std::result::Result<T, Error>;
 /// A cryptographic hash.
 #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Hashable, Clone, EnumDiscriminants)]
 #[strum_discriminants(derive(EnumString, Display))]
+#[strum_discriminants(name(HashType))]
 pub enum Hash {
     Sha2_256([u8; 32]),
     #[serde(with = "BigArray")]
@@ -82,11 +83,11 @@ impl Hash {
     const HASH_SEP: char = ':';
 }
 
-impl From<HashDiscriminants> for Hash {
-    fn from(discriminant: HashDiscriminants) -> Hash {
+impl From<HashType> for Hash {
+    fn from(discriminant: HashType) -> Hash {
         match discriminant {
-            HashDiscriminants::Sha2_512 => Hash::Sha2_512([0u8; 64]),
-            HashDiscriminants::Sha2_256 => Hash::Sha2_256([0u8; 32])
+            HashType::Sha2_512 => Hash::Sha2_512([0u8; 64]),
+            HashType::Sha2_256 => Hash::Sha2_256([0u8; 32])
         }
     }
 }
@@ -118,7 +119,7 @@ impl TryFrom<&str> for Hash {
         };
         let second = split.pop().ok_or(Error::InvalidFormat)?;
         let first = split.pop().ok_or(Error::InvalidFormat)?;
-        let mut hash = Hash::from(HashDiscriminants::from_str(first)
+        let mut hash = Hash::from(HashType::from_str(first)
             .map_err(|_| Error::InvalidFormat)?);
         base64_url::decode_to_slice(second, hash.as_mut())
             .map_err(|_| Error::InvalidFormat)?;
@@ -128,7 +129,7 @@ impl TryFrom<&str> for Hash {
 
 impl Display for Hash {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        let hash_type: HashDiscriminants = self.into();
+        let hash_type: HashType = self.into();
         let hash_data = base64_url::encode(self.as_ref());
         write!(f, "{}{}{}", hash_type, Hash::HASH_SEP, hash_data)
     }
@@ -173,14 +174,6 @@ impl Default for Signature {
     }
 }
 
-/// Identifies a type of cryptographic key. The variants of this enum match those of `Key`.
-pub enum KeyId {
-    // TODO: Write a macro to generate this from `Key`.
-    Aes256Cbc,
-    Aes256Ctr,
-    Rsa,
-}
-
 #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
 pub enum RsaPadding {
     None,
@@ -189,6 +182,12 @@ pub enum RsaPadding {
     Pkcs1Pss,
 }
 
+impl Default for RsaPadding {
+    fn default() -> Self {
+        RsaPadding::Pkcs1
+    }
+}
+
 impl Copy for RsaPadding {}
 
 impl From<RsaPadding> for OpensslPadding {
@@ -203,7 +202,8 @@ impl From<RsaPadding> for OpensslPadding {
 }
 
 /// A cryptographic key.
-#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
+#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, EnumDiscriminants)]
+#[strum_discriminants(name(KeyType))]
 pub enum Key {
     /// A key for the AES 256 cipher in Cipher Block Chaining mode. Note that this includes the
     /// initialization vector, so that a value of this variant contains all the information needed
@@ -246,23 +246,23 @@ impl Key {
         }
     }
 
-    pub fn generate(id: KeyId) -> Result<Key> {
+    pub fn generate(id: KeyType) -> Result<Key> {
         match id {
-            KeyId::Aes256Cbc => {
+            KeyType::Aes256Cbc => {
                 let mut key = [0; 32];
                 let mut iv = [0; 16];
                 rand_bytes(&mut key).map_err(Error::from)?;
                 rand_bytes(&mut iv).map_err(Error::from)?;
                 Ok(Key::Aes256Cbc { key, iv })
             },
-            KeyId::Aes256Ctr => {
+            KeyType::Aes256Ctr => {
                 let mut key = [0; 32];
                 let mut iv = [0; 16];
                 rand_bytes(&mut key).map_err(Error::from)?;
                 rand_bytes(&mut iv).map_err(Error::from)?;
                 Ok(Key::Aes256Ctr { key, iv })
             },
-            KeyId::Rsa => {
+            KeyType::Rsa => {
                 let key = PKey::from_rsa(Rsa::generate(4096)?)?;
                 let public= key.public_key_to_der().map_err(Error::from)?;
                 let private = Some(key.private_key_to_der().map_err(Error::from)?);
@@ -728,13 +728,13 @@ mod tests {
         let principal = make_principal();
         let block_key = Key::Aes256Ctr { key: BLOCK_KEY, iv: BLOCK_IV };
         let block = make_versioned_block(principal.clone(), block_key)?;
-        let key = Key::generate(KeyId::Rsa)?;
+        let key = Key::generate(KeyType::Rsa)?;
         encrypt_decrypt_block_test_case(block, &principal, &key)
     }
 
     #[test]
     fn rsa_sign_and_verify() -> Result<()> {
-        let key = Key::generate(KeyId::Rsa)?;
+        let key = Key::generate(KeyType::Rsa)?;
         let header = b"About: lyrics".as_slice();
         let message = b"Everything that feels so good is bad bad bad.".as_slice();
         let mut signer = SignAlgo::try_from(&key)?;
@@ -750,7 +750,7 @@ mod tests {
         let principal = make_principal();
         let block_key = Key::Aes256Ctr { key: BLOCK_KEY, iv: BLOCK_IV };
         let mut block = make_versioned_block(principal.clone(), block_key)?;
-        let key = Key::generate(KeyId::Rsa)?;
+        let key = Key::generate(KeyType::Rsa)?;
         let writecap = Writecap {
             issued_to: Principal(Hash::Sha2_256(PRINCIPAL)),
             path: make_path(vec!["contacts", "emergency"]),
@@ -898,7 +898,7 @@ mod tests {
         #[test]
         fn rsa_signature_len() -> Result<()> {
             use openssl::rsa::Rsa;
-            let key = Key::generate(KeyId::Rsa)?;
+            let key = Key::generate(KeyType::Rsa)?;
             let sign_algo = SignAlgo::try_from(&key)?;
             let private = match &key {
                 Key::Rsa { private: Some(private), .. } => private,

+ 2 - 2
crates/node/src/test_helpers.rs

@@ -1,7 +1,7 @@
 /// Test data and functions to help with testing.
 
 use super::*;
-use crypto::{ KeyId, RsaPadding };
+use crypto::{ KeyType, RsaPadding };
 use serde_block_tree::{Error, Result};
 use std::{ fs::File, io::Write, fmt::Write as FmtWrite };
 
@@ -581,7 +581,7 @@ impl<'a> NamedSlice<'a> {
 }
 
 fn write_rsa_keys_to_file(path: &str) -> Result<()> {
-    let key = Key::generate(KeyId::Rsa).map_err(|e| Error::Message(e.to_string()))?;
+    let key = Key::generate(KeyType::Rsa).map_err(|e| Error::Message(e.to_string()))?;
     let (public, private) = match key {
         Key::Rsa { public, private, .. } => (public, private.unwrap()),
         _ => return Err(Error::Message("unexpected key type".to_string())),