Browse Source

Introduced the Scheme trait and associated types
for representing cryptographic schemes.

Matthew Carr 2 years ago
parent
commit
142addac7a

+ 252 - 73
crates/btnode/src/crypto/mod.rs

@@ -20,7 +20,7 @@ use serde::{
     ser::{Serializer, SerializeStruct},
 };
 use std::{
-    str::FromStr, num::TryFromIntError,
+    str::FromStr, num::TryFromIntError, marker::PhantomData,
 };
 use strum_macros::{EnumString, EnumDiscriminants, Display};
 
@@ -35,7 +35,7 @@ pub(crate) enum Cryptotext<T: Serialize> {
 
 /// Errors that can occur during cryptographic operations.
 #[derive(Debug)]
-pub(crate) enum Error {
+pub enum Error {
     NoReadCap,
     NoKeyAvailable,
     MissingPrivateKey,
@@ -108,7 +108,7 @@ impl<T, E: Into<Error>> ConvErr<T> for std::result::Result<T, E> {
 
 /// A cryptographic hash.
 #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Hashable, Clone, EnumDiscriminants)]
-#[strum_discriminants(derive(EnumString, Display))]
+#[strum_discriminants(derive(EnumString, Display, Serialize, Deserialize))]
 #[strum_discriminants(name(HashKind))]
 pub(crate) enum Hash {
     Sha2_256([u8; 32]),
@@ -122,6 +122,15 @@ impl Default for HashKind {
     }
 }
 
+impl From<HashKind> for MessageDigest {
+    fn from(kind: HashKind) -> Self {
+        match kind {
+            HashKind::Sha2_256 => MessageDigest::sha256(),
+            HashKind::Sha2_512 => MessageDigest::sha512(),
+        }
+    }
+}
+
 impl Hash {
     /// The character that's used to separate a hash type from its value in its string
     /// representation.
@@ -283,59 +292,209 @@ impl SymKey {
     }
 }
 
-#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
-pub(crate) enum AsymKeyKind {
-    Rsa,
+pub trait Scheme: for<'de> Deserialize<'de> + Serialize + Copy + std::fmt::Debug + PartialEq {
+    type Kind;
+    fn kind(self) -> Self::Kind;
+    fn as_enum(self) -> SchemeKind;
+    fn message_digest(&self) -> MessageDigest;
+    fn padding(&self) -> Option<OpensslPadding>;
+    fn public_from_der(self, der: &[u8]) -> Result<PKey<Public>>;
 }
 
-#[derive(Debug, Clone)]
-pub(crate) struct AsymKeyPub {
-    kind: AsymKeyKind,
-    pkey: PKey<Public>,
+pub enum SchemeKind {
+    Sign(Sign),
+    Encrypt(Encrypt),
 }
 
-impl AsymKeyPub {
-    pub(crate) fn new(kind: AsymKeyKind, der: &[u8]) -> Result<AsymKeyPub> {
-        let pkey = match kind {
-            AsymKeyKind::Rsa => {
-                let rsa = Rsa::public_key_from_der(der).conv_err()?;
-                PKey::from_rsa(rsa).conv_err()?
-            }
-        };
-        Ok(AsymKeyPub { kind, pkey })
+#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Copy)]
+pub enum Encrypt {
+    RsaEsOaep(RsaEsOaep),
+}
+
+impl Scheme for Encrypt {
+    type Kind = Encrypt;
+
+    fn kind(self) -> Encrypt {
+        self
     }
 
-    fn digest(&self) -> MessageDigest {
-        match self.kind {
-            AsymKeyKind::Rsa => MessageDigest::sha256(),
+    fn as_enum(self) -> SchemeKind {
+        SchemeKind::Encrypt(self.kind())
+    }
+
+    fn message_digest(&self) -> MessageDigest {
+        match self {
+            Encrypt::RsaEsOaep(inner) => inner.message_digest(),
         }
     }
 
-    fn signature_buf(&self) -> Signature {
-        match self.kind {
-            AsymKeyKind::Rsa => Signature::new(SignatureKind::Rsa),
+    fn padding(&self) -> Option<OpensslPadding> {
+        match self {
+            Encrypt::RsaEsOaep(inner) => inner.padding(),
+        }
+    }
+
+    fn public_from_der(self, der: &[u8]) -> Result<PKey<Public>> {
+        match self {
+            Encrypt::RsaEsOaep(inner) => inner.public_from_der(der),
+        }
+    }
+}
+
+impl Encrypt {
+    pub const RSA_OAEP_3072_SHA_256: Encrypt = Encrypt::RsaEsOaep(RsaEsOaep {
+        key_bytes: 384,
+        hash_kind: HashKind::Sha2_256,
+    });
+}
+
+#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Copy)]
+pub enum Sign {
+    RsaSsaPss(RsaSsaPss),
+}
+
+impl Scheme for Sign {
+    type Kind = Sign;
+
+    fn kind(self) -> Sign {
+        self
+    }
+
+    fn as_enum(self) -> SchemeKind {
+        SchemeKind::Sign(self.kind())
+    }
+
+    fn message_digest(&self) -> MessageDigest {
+        match self {
+            Sign::RsaSsaPss(inner) => inner.message_digest(),
         }
     }
 
     fn padding(&self) -> Option<OpensslPadding> {
-        match self.kind {
-            AsymKeyKind::Rsa => Some(OpensslPadding::PKCS1),
+        match self {
+            Sign::RsaSsaPss(inner) => inner.padding(),
         }
     }
 
+    fn public_from_der(self, der: &[u8]) -> Result<PKey<Public>> {
+        match self {
+            Sign::RsaSsaPss(inner) => inner.public_from_der(der),
+        }
+    }
+}
+
+impl Sign {
+    pub const RSA_PSS_3072_SHA_256: Sign = Sign::RsaSsaPss(RsaSsaPss {
+        key_bytes: 384,
+        hash_kind: HashKind::Sha2_256,
+    });
+
+    fn sig_buf(&self) -> Signature {
+        match self {
+            Sign::RsaSsaPss(_) => Signature::Rsa([0u8; RSA_KEY_BYTES]),
+        }
+    }
+}
+
+#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Copy)]
+pub struct RsaEsOaep {
+    key_bytes: usize,
+    hash_kind: HashKind,
+}
+
+impl Scheme for RsaEsOaep {
+    type Kind = Encrypt;
+
+    fn kind(self) -> Encrypt {
+        Encrypt::RsaEsOaep(self)
+    }
+
+    fn as_enum(self) -> SchemeKind {
+        SchemeKind::Encrypt(self.kind())
+    }
+
+    fn message_digest(&self) -> MessageDigest {
+        self.hash_kind.into()
+    }
+
+    fn padding(&self) -> Option<OpensslPadding> {
+        Some(OpensslPadding::PKCS1_OAEP)
+    }
+
+    fn public_from_der(self, der: &[u8]) -> Result<PKey<Public>> {
+        PKey::public_key_from_der(der).conv_err()
+    }
+}
+
+#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Copy)]
+pub struct RsaSsaPss {
+    key_bytes: usize,
+    hash_kind: HashKind,
+}
+
+impl Scheme for RsaSsaPss {
+    type Kind = Sign;
+
+    fn kind(self) -> Sign {
+        Sign::RsaSsaPss(self)
+    }
+
+    fn as_enum(self) -> SchemeKind {
+        SchemeKind::Sign(self.kind())
+    }
+
+    fn message_digest(&self) -> MessageDigest {
+        self.hash_kind.into()
+    }
+
+    fn padding(&self) -> Option<OpensslPadding> {
+        Some(OpensslPadding::PKCS1_PSS)
+    }
+
+    fn public_from_der(self, der: &[u8]) -> Result<PKey<Public>> {
+        PKey::public_key_from_der(der).conv_err()
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct AsymKeyPub<S: Scheme> {
+    scheme: S,
+    pkey: PKey<Public>,
+}
+
+impl<S: Scheme> AsymKeyPub<S> {
+    pub(crate) fn new(scheme: S, der: &[u8]) -> Result<AsymKeyPub<S>> {
+        let pkey = scheme.public_from_der(der)?;
+        Ok(AsymKeyPub { scheme, pkey })
+    }
+
+    fn digest(&self) -> MessageDigest {
+        self.scheme.message_digest()
+    }
+
+    fn padding(&self) -> Option<OpensslPadding> {
+        self.scheme.padding()
+    }
+
     fn to_der(&self) -> Result<Vec<u8>> {
         self.pkey.public_key_to_der().conv_err()
     }
 }
 
-impl<'de> Deserialize<'de> for AsymKeyPub {
+impl AsymKeyPub<Sign> {
+    fn signature_buf(&self) -> Signature {
+        self.scheme.sig_buf()
+    }
+}
+
+impl<'de, S: Scheme> Deserialize<'de> for AsymKeyPub<S> {
     fn deserialize<D: Deserializer<'de>>(d: D) -> std::result::Result<Self, D::Error> {
-        const FIELDS: &[&str] = &["kind", "pkey"];
+        const FIELDS: &[&str] = &["scheme", "pkey"];
 
-        struct StructVisitor;
+        struct StructVisitor<S: Scheme>(PhantomData<S>);
 
-        impl<'de> Visitor<'de> for StructVisitor {
-            type Value = AsymKeyPub;
+        impl<'de, S: Scheme> Visitor<'de> for StructVisitor<S> {
+            type Value = AsymKeyPub<S>;
 
             fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                 formatter.write_fmt(format_args!("struct {}", stringify!(AsymKeyPub)))
@@ -343,37 +502,37 @@ impl<'de> Deserialize<'de> for AsymKeyPub {
 
             fn visit_seq<V: SeqAccess<'de>>(
                 self, mut seq: V
-            ) -> std::result::Result<AsymKeyPub, V::Error> {
-                let kind: AsymKeyKind = seq.next_element()?
+            ) -> std::result::Result<Self::Value, V::Error> {
+                let scheme: S = seq.next_element()?
                     .ok_or_else(|| de::Error::missing_field(FIELDS[0]))?;
                 let der: Vec<u8> = seq.next_element()?
                     .ok_or_else(|| de::Error::missing_field(FIELDS[1]))?;
-                AsymKeyPub::new(kind, der.as_slice())
+                AsymKeyPub::new(scheme, der.as_slice())
                     .map_err(de::Error::custom)
             }
         }
 
-        d.deserialize_struct(stringify!(AsymKeyPub), FIELDS, StructVisitor)
+        d.deserialize_struct(stringify!(AsymKeyPub), FIELDS, StructVisitor(PhantomData))
     }
 }
 
-impl Serialize for AsymKeyPub {
-    fn serialize<S: Serializer>(&self, s: S) -> std::result::Result<S::Ok, S::Error> {
+impl<S: Scheme> Serialize for AsymKeyPub<S> {
+    fn serialize<T: Serializer>(&self, s: T) -> std::result::Result<T::Ok, T::Error> {
         let mut struct_s = s.serialize_struct(stringify!(AsymKeyPub), 2)?;
-        struct_s.serialize_field("kind", &self.kind)?;
+        struct_s.serialize_field("scheme", &self.scheme)?;
         let der = self.pkey.public_key_to_der().unwrap();
         struct_s.serialize_field("pkey", der.as_slice())?;
         struct_s.end()
     }
 }
 
-impl PartialEq for AsymKeyPub {
+impl<S: Scheme> PartialEq for AsymKeyPub<S> {
     fn eq(&self, other: &Self) -> bool {
-        self.kind == other.kind && self.pkey.public_eq(&other.pkey)
+        self.scheme == other.scheme && self.pkey.public_eq(&other.pkey)
     }
 }
 
-impl Owned for AsymKeyPub {
+impl Owned for AsymKeyPub<Sign> {
     fn owner_of_kind(&self, kind: HashKind) -> Principal {
         match kind {
             HashKind::Sha2_256 => {
@@ -394,17 +553,15 @@ impl Owned for AsymKeyPub {
     }
 }
 
-impl CredsPub for AsymKeyPub {}
-
-pub(crate) struct RsaPriv {
+pub(crate) struct ConcreteCredsPriv {
     pkey: PKey<Private>,
 }
 
-impl RsaPriv {
-    pub(crate) fn new(der: &[u8]) -> Result<RsaPriv> {
+impl ConcreteCredsPriv {
+    pub(crate) fn new(der: &[u8]) -> Result<ConcreteCredsPriv> {
         let rsa = Rsa::private_key_from_der(der).conv_err()?;
         let pkey = PKey::from_rsa(rsa).conv_err()?;
-        Ok(RsaPriv { pkey })
+        Ok(ConcreteCredsPriv { pkey })
     }
 
     fn digest() -> MessageDigest {
@@ -416,7 +573,7 @@ impl RsaPriv {
     }
 }
 
-impl Decrypter for RsaPriv {
+impl Decrypter for ConcreteCredsPriv {
     fn decrypt(&self, slice: &[u8]) -> Result<Vec<u8>> {
         let decrypter = OsslDecrypter::new(&self.pkey).conv_err()?;
         let buffer_len = decrypter.decrypt_len(slice).conv_err()?;
@@ -427,10 +584,10 @@ impl Decrypter for RsaPriv {
     }
 }
 
-impl Signer for RsaPriv {
+impl Signer for ConcreteCredsPriv {
     fn sign<'a, I: Iterator<Item=&'a [u8]>>(&self, parts: I) -> Result<Signature> {
-        let digest = RsaPriv::digest();
-        let mut signature = RsaPriv::signature_buf();
+        let digest = ConcreteCredsPriv::digest();
+        let mut signature = ConcreteCredsPriv::signature_buf();
 
         let mut signer = OsslSigner::new(digest, &self.pkey).conv_err()?;
         for part in parts {
@@ -442,30 +599,58 @@ impl Signer for RsaPriv {
     }
 }
 
-impl CredsPriv for RsaPriv {}
+impl CredsPriv for ConcreteCredsPriv {}
+
+pub struct ConcreteCredsPub {
+    pub encrypt: AsymKeyPub<Encrypt>,
+    pub sign: AsymKeyPub<Sign>,
+}
+
+impl Verifier for ConcreteCredsPub {
+    fn verify<'a, I: Iterator<Item=&'a [u8]>>(&self, parts: I, signature: &[u8]) -> Result<bool> {
+        self.sign.verify(parts, signature)
+    }
+}
+
+impl Encrypter for ConcreteCredsPub {
+    fn encrypt(&self, slice: &[u8]) -> Result<Vec<u8>> {
+        self.encrypt.encrypt(slice)
+    }
+}
+
+impl Owned for ConcreteCredsPub {
+    fn owner_of_kind(&self, kind: HashKind) -> Principal {
+        self.sign.owner_of_kind(kind)
+    }
+}
+
+impl CredsPub for ConcreteCredsPub {}
 
 pub(crate) struct ConcreteCreds<T: CredsPriv> {
-    public: AsymKeyPub,
+    public: ConcreteCredsPub,
     private: T,
 }
 
 impl<T: CredsPriv> ConcreteCreds<T> {
-    pub(crate) fn new(public: AsymKeyPub, private: T) -> ConcreteCreds<T> {
+    pub(crate) fn new(public: ConcreteCredsPub, private: T) -> ConcreteCreds<T> {
         ConcreteCreds { public, private }
     }
 }
 
 
-impl ConcreteCreds<RsaPriv> {
-    pub(crate) fn generate() -> Result<ConcreteCreds<RsaPriv>> {
+impl ConcreteCreds<ConcreteCredsPriv> {
+    pub(crate) fn generate() -> Result<ConcreteCreds<ConcreteCredsPriv>> {
         let key_bits = 8 * u32::try_from(RSA_KEY_BYTES).conv_err()?;
         let key = Rsa::generate(key_bits)?;
         // TODO: Separating the keys this way seems inefficient. Investigate alternatives.
         let public_der = key.public_key_to_der().conv_err()?;
         let private_der = key.private_key_to_der().conv_err()?;
         Ok(ConcreteCreds {
-            public: AsymKeyPub::new(AsymKeyKind::Rsa, public_der.as_slice())?,
-            private: RsaPriv::new(private_der.as_slice())?,
+            public: ConcreteCredsPub {
+                encrypt: AsymKeyPub::new(Encrypt::RSA_OAEP_3072_SHA_256, public_der.as_slice())?,
+                sign: AsymKeyPub::new(Sign::RSA_PSS_3072_SHA_256, public_der.as_slice())?,
+            },
+            private: ConcreteCredsPriv::new(private_der.as_slice())?,
         })
     }
 }
@@ -505,8 +690,8 @@ impl<T: CredsPriv> Decrypter for ConcreteCreds<T> {
 impl<T: CredsPriv> CredsPriv for ConcreteCreds<T> {}
 
 impl<T: CredsPriv> Creds for ConcreteCreds<T> {
-    fn public(&self) -> &AsymKeyPub {
-        &self.public
+    fn public(&self) -> &AsymKeyPub<Sign> {
+        &self.public.sign
     }
 }
 
@@ -521,7 +706,7 @@ impl Encrypter for SymKey {
     }
 }
 
-impl Encrypter for AsymKeyPub {
+impl Encrypter for AsymKeyPub<Encrypt> {
     fn encrypt(&self, slice: &[u8]) -> Result<Vec<u8>> {
         let mut encrypter = OsslEncrypter::new(&self.pkey).conv_err()?;
         if let Some(padding) = self.padding() {
@@ -555,7 +740,7 @@ pub(crate) trait Verifier {
     fn verify<'a, I: Iterator<Item=&'a [u8]>>(&self, parts: I, signature: &[u8]) -> Result<bool>;
 }
 
-impl Verifier for AsymKeyPub {
+impl Verifier for AsymKeyPub<Sign> {
     fn verify<'a, I: Iterator<Item=&'a [u8]>>(&self, parts: I, signature: &[u8]) -> Result<bool> {
         let mut verifier = OsslVerifier::new(self.digest(), &self.pkey).conv_err()?;
         for part in parts {
@@ -573,7 +758,7 @@ pub(crate) trait CredsPriv: Decrypter + Signer {}
 
 /// Trait for types which contain both public and private credentials.
 pub(crate) trait Creds: CredsPriv + CredsPub {
-    fn public(&self) -> &AsymKeyPub;
+    fn public(&self) -> &AsymKeyPub<Sign>;
 }
 
 /// A trait for types which store credentials.
@@ -692,7 +877,7 @@ struct WritecapSigInput<'a> {
     issued_to: &'a Principal,
     path: &'a Path,
     expires: &'a Epoch,
-    signing_key: &'a AsymKeyPub,
+    signing_key: &'a AsymKeyPub<Sign>,
 }
 
 impl<'a> From<&'a Writecap> for WritecapSigInput<'a> {
@@ -907,10 +1092,7 @@ mod tests {
         let (mut root_writecap, root_key) = make_self_signed_writecap()?;
         root_writecap.issued_to = Principal(Hash::Sha2_256([0; 32]));
         sign_writecap(&mut root_writecap, &root_key)?;
-        let node_key = AsymKeyPub {
-            kind: AsymKeyKind::Rsa,
-            pkey: PKey::from_rsa(Rsa::public_key_from_der(NODE_PUBLIC_KEY.as_slice())?)?
-        };
+        let node_key = AsymKeyPub::new(Sign::RSA_PSS_3072_SHA_256, NODE_PUBLIC_KEY.as_slice())?;
         let node_principal = node_key.owner();
         let writecap = make_writecap_trusted_by(
             root_writecap, root_key, node_principal, vec!["apps", "contacts"])?;
@@ -925,10 +1107,7 @@ mod tests {
         let owner = Principal(Hash::Sha2_256([0; 32]));
         root_writecap.path = make_path_with_owner(owner, vec![]);
         sign_writecap(&mut root_writecap, &root_key)?;
-        let node_key = AsymKeyPub {
-            kind: AsymKeyKind::Rsa,
-            pkey: PKey::from_rsa(Rsa::public_key_from_der(NODE_PUBLIC_KEY.as_slice())?)?
-        };
+        let node_key = AsymKeyPub::new(Sign::RSA_PSS_3072_SHA_256, NODE_PUBLIC_KEY.as_slice())?;
         let node_owner = node_key.owner();
         let writecap = make_writecap_trusted_by(
             root_writecap, root_key, node_owner, vec!["apps", "contacts"])?;

+ 125 - 123
crates/btnode/src/crypto/tpm.rs

@@ -281,26 +281,55 @@ impl Cookie {
     }
 }
 
-#[derive(Serialize, Deserialize, Clone)]
-struct StoredKeyPair {
-    public: AsymKeyPub,
+#[derive(Serialize, Clone)]
+struct StoredKeyPair<S: Scheme> {
+    public: AsymKeyPub<S>,
     private: TPM2_HANDLE,
 }
 
+impl<'de, S: Scheme> Deserialize<'de> for StoredKeyPair<S> {
+    fn deserialize<D: Deserializer<'de>>(d: D) -> std::result::Result<Self, D::Error> {
+        const FIELDS: &[&str] = &["public", "private"];
+
+        struct StructVisitor<S: Scheme>(PhantomData<S>);
+
+        impl<'de, S: Scheme> Visitor<'de> for StructVisitor<S> {
+            type Value = StoredKeyPair<S>;
+
+            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+                formatter.write_fmt(format_args!("struct {}", stringify!(StoredKeyPair)))
+            }
+
+            fn visit_seq<A: SeqAccess<'de>>(
+                self, mut seq: A
+            ) -> std::result::Result<Self::Value, A::Error> {
+                let public : AsymKeyPub<S> = seq.next_element()?
+                    .ok_or_else(|| de::Error::missing_field(FIELDS[0]))?;
+                let private: TPM2_HANDLE = seq.next_element()?
+                    .ok_or_else(|| de::Error::missing_field(FIELDS[1]))?;
+                let pair = StoredKeyPair { public, private };
+                Ok(pair)
+            }
+        }
+
+        d.deserialize_struct(stringify!(StoredKeyPair), FIELDS, StructVisitor(PhantomData))
+    }
+}
+
 #[derive(Serialize, Deserialize, Clone)]
 struct TpmHandles {
-    sign: StoredKeyPair,
-    enc: StoredKeyPair,
+    sign: StoredKeyPair<Sign>,
+    enc: StoredKeyPair<Encrypt>,
 }
 
 impl TpmHandles {
-    fn new(sign: StoredKeyPair, enc: StoredKeyPair) -> TpmHandles {
+    fn new(sign: StoredKeyPair<Sign>, enc: StoredKeyPair<Encrypt>) -> TpmHandles {
         TpmHandles { sign, enc }
     }
 
     fn to_key_handles(&self, context: &mut Context) -> Result<KeyHandles> {
-        let sign = SignKey(KeyPair::from_stored(context, &self.sign)?);
-        let enc = EncKey(KeyPair::from_stored(context, &self.enc)?);
+        let sign = KeyPair::from_stored(context, &self.sign)?;
+        let enc = KeyPair::from_stored(context, &self.enc)?;
         Ok(KeyHandles {sign, enc})
     }
 }
@@ -384,58 +413,57 @@ impl KeyKind {
     }
 }
 
-struct KeyParams<'a> {
-    kind: KeyKind,
+impl TryInto<RsaScheme> for SchemeKind {
+    type Error = Error;
+    fn try_into(self) -> Result<RsaScheme> {
+        match self {
+            SchemeKind::Sign(sign) => match sign {
+                Sign::RsaSsaPss(_)
+                    => Ok(RsaScheme::RsaPss(HashScheme::new(HashingAlgorithm::Sha256))),
+            },
+            SchemeKind::Encrypt(encrypt) => match encrypt {
+                Encrypt::RsaEsOaep(_)
+                    => Ok(RsaScheme::Oaep(HashScheme::new(HashingAlgorithm::Sha256))),
+            }
+        }
+    }
+}
+
+struct KeyBuilder<'a, S: Scheme> {
+    scheme: S,
     allow_dup: bool,
-    scheme: RsaScheme,
     unique: &'a [u8],
     auth: Option<Auth>,
 }
 
-impl<'a> KeyParams<'a> {
+impl<'a, S: Scheme> KeyBuilder<'a, S> {
     /// The public exponent to use for generated RSA keys.
     const RSA_EXPONENT: u32 = 65537; // 2**16 + 1
 
     const RSA_KEY_BITS: RsaKeyBits = RsaKeyBits::Rsa3072;
 
-    fn with_unique(unique: &'a [u8]) -> KeyParams<'a> {
-        KeyParams {
-            kind: KeyKind::Sign,
+    fn new(scheme: S, unique: &'a [u8]) -> KeyBuilder<'a, S> {
+        KeyBuilder {
+            scheme,
             allow_dup: false,
-            scheme: RsaScheme::Null,
             unique,
             auth: None,
         }
     }
 
-    fn with_kind(mut self, kind: KeyKind) -> Self {
-        self.kind = kind;
-        self
-    }
-
-    fn with_allow_dup(mut self, allow_dup: bool) -> Self {
-        self.allow_dup = allow_dup;
-        self
-    }
-
-    fn with_scheme(mut self, scheme: RsaScheme) -> Self {
-        self.scheme = scheme;
-        self
-    }
-
-    fn with_auth(mut self, auth: Auth) -> Self {
-        self.auth = Some(auth);
-        self
-    }
-
     fn template(&self) -> Result<Public> {
+        let scheme_enum = self.scheme.as_enum();
+        let decrypt = match scheme_enum {
+            SchemeKind::Sign(_) => false,
+            SchemeKind::Encrypt(_) => true,
+        };
         let object_attributes = ObjectAttributes::builder()
             .with_fixed_tpm(!self.allow_dup)
             .with_fixed_parent(!self.allow_dup)
             .with_sensitive_data_origin(true)
             .with_user_with_auth(true)
-            .with_decrypt(self.kind.decrypt())
-            .with_sign_encrypt(self.kind.sign())
+            .with_decrypt(decrypt)
+            .with_sign_encrypt(!decrypt)
             .with_restricted(false)
             .build()
             .conv_err()?;
@@ -443,7 +471,7 @@ impl<'a> KeyParams<'a> {
         let auth_policy = Digest::empty();
         let parameters = PublicRsaParameters::new(
             SymmetricDefinitionObject::Null,
-            self.scheme,
+            scheme_enum.try_into()?,
             Self::RSA_KEY_BITS,
             RsaExponent::try_from(Self::RSA_EXPONENT).conv_err()?, 
         );
@@ -458,57 +486,41 @@ impl<'a> KeyParams<'a> {
         };
         Ok(public)
     }
+
+    fn with_allow_dup(mut self, allow_dup: bool) -> Self {
+        self.allow_dup = allow_dup;
+        self
+    }
+
+    fn with_auth(mut self, auth: Auth) -> Self {
+        self.auth = Some(auth);
+        self
+    }
 }
 
 #[derive(Clone)]
-struct KeyPair {
-    public: AsymKeyPub,
+struct KeyPair<S: Scheme> {
+    public: AsymKeyPub<S>,
     /// A rust struct which wraps an `ESYS_TR` from the ESAPI.
     private: KeyHandle
 }
 
-impl  KeyPair {
-    fn from_stored(context: &mut Context, stored: &StoredKeyPair) -> Result<KeyPair> {
+impl<S: Scheme> KeyPair<S> {
+    fn from_stored(context: &mut Context, stored: &StoredKeyPair<S>) -> Result<KeyPair<S>> {
         let tpm_handle = TpmHandle::try_from(stored.private).conv_err()?;
         let key_handle = context.key_handle(tpm_handle)?;
         Ok(KeyPair { public: stored.public.clone(), private: key_handle })
     }
 
-    fn to_stored(&self, private: TPM2_HANDLE) -> StoredKeyPair {
+    fn to_stored(&self, private: TPM2_HANDLE) -> StoredKeyPair<S> {
         let public = self.public.clone();
         StoredKeyPair { public, private }
     }
 }
 
-#[derive(Clone)]
-struct SignKey(KeyPair);
-
-impl SignKey {
-    fn public(&self) -> &AsymKeyPub {
-        &self.0.public
-    }
-
-    fn private(&self) -> KeyHandle {
-        self.0.private
-    }
-}
-
-#[derive(Clone)]
-struct EncKey(KeyPair);
-
-impl EncKey {
-    fn public(&self) -> &AsymKeyPub {
-        &self.0.public
-    }
-
-    fn private(&self) -> KeyHandle {
-        self.0.private
-    }
-}
-
 struct KeyHandles {
-    sign: SignKey,
-    enc: EncKey,
+    sign: KeyPair<Sign>,
+    enc: KeyPair<Encrypt>,
 }
 
 struct State {
@@ -529,8 +541,8 @@ impl State {
         };
         let key_handles = tpm_handles.to_key_handles(&mut self.context)?;
         let auth = self.storage.cookie.auth();
-        self.context.tr_set_auth(key_handles.enc.private().into(), auth.clone())?;
-        self.context.tr_set_auth(key_handles.sign.private().into(), auth)?;
+        self.context.tr_set_auth(key_handles.enc.private.into(), auth.clone())?;
+        self.context.tr_set_auth(key_handles.sign.private.into(), auth)?;
         self.node_creds = Some(TpmCreds::new(key_handles, state));
         Ok(())
     }
@@ -564,7 +576,7 @@ impl TpmCredStore {
         Ok(())
     }
 
-    fn gen_key(&self, params: KeyParams) -> Result<KeyPair> {
+    fn gen_key<S: Scheme>(&self, params: KeyBuilder<S>) -> Result<KeyPair<S>> {
         let result = {
             let mut guard = self.state.write().conv_err()?;
             guard.context.create_primary(
@@ -577,52 +589,48 @@ impl TpmCredStore {
             )
             .conv_err()?
         };
-        let public = AsymKeyPub::try_from(result.out_public)?;
+        let public = AsymKeyPub::try_from(result.out_public, params.scheme)?;
         Ok(KeyPair { public, private: result.key_handle })
     }
 
-    fn gen_node_sign_key(&self) -> Result<SignKey> {
-        let params = KeyParams::with_unique(self.cookie.as_slice())
+    fn gen_node_sign_key(&self) -> Result<KeyPair<Sign>> {
+        let params = KeyBuilder::new(Sign::RSA_PSS_3072_SHA_256, self.cookie.as_slice())
             .with_allow_dup(false)
-            .with_kind(KeyKind::Sign)
-            .with_scheme(RsaScheme::Null)
             .with_auth(self.cookie.auth());
-        Ok(SignKey(self.gen_key(params)?))
+        self.gen_key(params)
     }
 
-    fn gen_node_enc_key(&self) -> Result<EncKey> {
-        let params = KeyParams::with_unique(self.cookie.as_slice())
+    fn gen_node_enc_key(&self) -> Result<KeyPair<Encrypt>> {
+        let params = KeyBuilder::new(Encrypt::RSA_OAEP_3072_SHA_256, self.cookie.as_slice())
             .with_allow_dup(false)
-            .with_kind(KeyKind::Decrypt)
-            .with_scheme(RsaScheme::RsaEs)
             .with_auth(self.cookie.auth());
-        Ok(EncKey(self.gen_key(params)?))
+        self.gen_key(params)
     }
 
     fn persist<F: FnOnce(&mut Storage, TpmHandles)>(
         &self, creds: &TpmCreds, update_storage: F
     ) -> Result<()> {
         let mut guard = self.state.write().conv_err()?;
-        let sign_handle = guard.context.persist_key(creds.sign.private())?;
-        let enc_handle = match guard.context.persist_key(creds.enc.private()) {
+        let sign_handle = guard.context.persist_key(creds.sign.private)?;
+        let enc_handle = match guard.context.persist_key(creds.enc.private) {
             Ok(handle) => handle,
             Err(error) => {
-                guard.context.evict_key(sign_handle, Some(creds.sign.private()))?;
+                guard.context.evict_key(sign_handle, Some(creds.sign.private))?;
                 return Err(error)
             }
         };
         let handles = TpmHandles::new(
-            creds.sign.0.to_stored(sign_handle),
-            creds.enc.0.to_stored(enc_handle));
+            creds.sign.to_stored(sign_handle),
+            creds.enc.to_stored(enc_handle));
         update_storage(&mut guard.storage, handles);
         match self.save_storage(&mut guard) {
             Ok(_) => Ok(()),
             Err(error) => {
-                let result = guard.context.evict_key(sign_handle, Some(creds.sign.private()));
+                let result = guard.context.evict_key(sign_handle, Some(creds.sign.private));
                 if let Err(error) = result {
                     error!("failed to evict signing key due to error: {:?}", error)
                 }
-                let result = guard.context.evict_key(enc_handle, Some(creds.enc.private()));
+                let result = guard.context.evict_key(enc_handle, Some(creds.enc.private));
                 if let Err(error) = result {
                     error!("failed to evict encryption key due to error: {:?}", error)
                 }
@@ -639,24 +647,20 @@ impl TpmCredStore {
         Ok(creds)
     }
 
-    fn gen_root_sign_key(&self, password: &str) -> Result<SignKey> {
+    fn gen_root_sign_key(&self, password: &str) -> Result<KeyPair<Sign>> {
         let unique: [u8; COOKIE_LEN] = rand_array()?;
-        let params = KeyParams::with_unique(unique.as_slice())
+        let params = KeyBuilder::new(Sign::RSA_PSS_3072_SHA_256, unique.as_slice())
             .with_allow_dup(true)
-            .with_kind(KeyKind::Sign)
-            .with_scheme(RsaScheme::Null)
             .with_auth(Auth::try_from(password.as_bytes()).conv_err()?);
-        Ok(SignKey(self.gen_key(params)?))
+        self.gen_key(params)
     }
 
-    fn gen_root_enc_key(&self, password: &str) -> Result<EncKey> {
+    fn gen_root_enc_key(&self, password: &str) -> Result<KeyPair<Encrypt>> {
         let unique: [u8; COOKIE_LEN] = rand_array()?;
-        let params = KeyParams::with_unique(unique.as_slice())
+        let params = KeyBuilder::new(Encrypt::RSA_OAEP_3072_SHA_256, unique.as_slice())
             .with_allow_dup(true)
-            .with_kind(KeyKind::Decrypt)
-            .with_scheme(RsaScheme::RsaEs)
             .with_auth(Auth::try_from(password.as_bytes()).conv_err()?);
-        Ok(EncKey(self.gen_key(params)?))
+        self.gen_key(params)
     }
 }
 
@@ -683,8 +687,8 @@ impl CredStore for TpmCredStore {
         let mut guard = self.state.write().conv_err()?;
         let key_handles = root_handles.to_key_handles(&mut guard.context)?;
         let auth = Auth::try_from(password.as_bytes()).conv_err()?;
-        guard.context.tr_set_auth(key_handles.sign.private().into(), auth.clone())?;
-        guard.context.tr_set_auth(key_handles.enc.private().into(), auth)?;
+        guard.context.tr_set_auth(key_handles.sign.private.into(), auth.clone())?;
+        guard.context.tr_set_auth(key_handles.enc.private.into(), auth)?;
         Ok(TpmCreds::new(key_handles, &self.state))
     }
 
@@ -703,10 +707,8 @@ impl CredStore for TpmCredStore {
     }
 }
 
-impl TryFrom<Public> for AsymKeyPub {
-    type Error = Error;
-
-    fn try_from(public: Public) -> Result<AsymKeyPub> {
+impl<S: Scheme> AsymKeyPub<S> {
+    fn try_from(public: Public, scheme: S) -> Result<AsymKeyPub<S>> {
         match public {
             Public::Rsa { parameters, unique, .. } => {
                 let exponent_value = parameters.exponent().value();
@@ -714,7 +716,7 @@ impl TryFrom<Public> for AsymKeyPub {
                 let modulus = BigNum::from_slice(unique.as_slice()).conv_err()?;
                 let rsa = Rsa::from_public_components(modulus, exponent).conv_err()?;
                 let pkey = PKey::from_rsa(rsa).conv_err()?;
-                Ok(AsymKeyPub { pkey, kind: AsymKeyKind::Rsa })
+                Ok(AsymKeyPub { pkey, scheme })
             },
             _ => Err(Error::custom("Unsupported key type returned by TPM")),
         }
@@ -797,8 +799,8 @@ impl HashcheckTicketExt for HashcheckTicket {
 #[derive(Clone)]
 pub(crate) struct TpmCreds {
     state: Arc<RwLock<State>>,
-    sign: SignKey,
-    enc: EncKey,
+    sign: KeyPair<Sign>,
+    enc: KeyPair<Encrypt>,
 }
 
 impl TpmCreds {
@@ -811,8 +813,8 @@ impl Owned for TpmCreds {
     fn owner_of_kind(&self, kind: HashKind) -> Principal {
         fn hash(creds: &TpmCreds, msg_digest: MessageDigest, mut buf: &mut [u8]) {
             let digest = {
-                let sign_der = creds.sign.public().to_der().unwrap();
-                let enc_der = creds.enc.public().to_der().unwrap();
+                let sign_der = creds.sign.public.to_der().unwrap();
+                let enc_der = creds.enc.public.to_der().unwrap();
                 let mut hasher = Hasher::new(msg_digest).unwrap();
                 hasher.update(sign_der.as_slice()).unwrap();
                 hasher.update(enc_der.as_slice()).unwrap();
@@ -837,13 +839,13 @@ impl Owned for TpmCreds {
 
 impl Verifier for TpmCreds {
     fn verify<'a, I: Iterator<Item=&'a [u8]>>(&self, parts: I, signature: &[u8]) -> Result<bool> {
-        self.sign.public().verify(parts, signature)
+        self.sign.public.verify(parts, signature)
     }
 }
 
 impl Encrypter for TpmCreds {
     fn encrypt(&self, slice: &[u8]) -> Result<Vec<u8>> {
-        self.enc.public().encrypt(slice)
+        self.enc.public.encrypt(slice)
     }
 }
 
@@ -851,7 +853,7 @@ impl CredsPub for TpmCreds {}
 
 impl Signer for TpmCreds {
     fn sign<'a, I: Iterator<Item=&'a [u8]>>(&self, parts: I) -> Result<Signature> {
-        let msg_digest = self.sign.public().digest();
+        let msg_digest = self.sign.public.digest();
         let digest = {
             let mut hasher = Hasher::new(msg_digest).conv_err()?;
             for part in parts {
@@ -865,7 +867,7 @@ impl Signer for TpmCreds {
         let scheme = SignatureScheme::RsaSsa { hash_scheme: msg_digest.hash_scheme()? };
         let sig = {
             let mut guard = self.state.write().conv_err()?;
-            guard.context.sign(self.sign.private(), digest, scheme, validation)
+            guard.context.sign(self.sign.private, digest, scheme, validation)
                 .conv_err()?
         };
         let buf = match sig {
@@ -890,7 +892,7 @@ impl Decrypter for TpmCreds {
         let label = Data::try_from(empty.as_slice()).conv_err()?;
         let plain_text = {
             let mut guard = self.state.write().conv_err()?;
-            guard.context.rsa_decrypt(self.enc.private(), cipher_text, in_scheme, label)?
+            guard.context.rsa_decrypt(self.enc.private, cipher_text, in_scheme, label)?
         };
         Ok(Vec::from(plain_text.value()))
     } 
@@ -899,8 +901,8 @@ impl Decrypter for TpmCreds {
 impl CredsPriv for TpmCreds {}
 
 impl Creds for TpmCreds {
-    fn public(&self) -> &AsymKeyPub {
-        unimplemented!()
+    fn public(&self) -> &AsymKeyPub<Sign> {
+        &self.sign.public
     }
 }
 
@@ -1272,7 +1274,7 @@ active_pcr_banks = sha256
     /// Checks that the value of `TpmCredStore::RSA_KEY_BITS` matches the value of `RSA_KEY_BYTES`.
     #[test]
     fn rsa_key_bits_and_key_bytes_compatible() {
-        let bytes = match KeyParams::RSA_KEY_BITS {
+        let bytes = match KeyBuilder::<Sign>::RSA_KEY_BITS {
             RsaKeyBits::Rsa1024 => 128,
             RsaKeyBits::Rsa2048 => 256,
             RsaKeyBits::Rsa3072 => 384,
@@ -1300,7 +1302,7 @@ active_pcr_banks = sha256
     fn persist_key() -> Result<()> {
         let (_harness, store) = test_store()?;
         let cookie = Cookie::random()?;
-        let params = KeyParams::with_unique(cookie.as_slice());
+        let params = KeyBuilder::new(Sign::RSA_PSS_3072_SHA_256, cookie.as_slice());
         let pair = store.gen_key(params)?;
         let mut guard = store.state.write().conv_err()?;
         guard.context.persist_key(pair.private)?;

+ 2 - 2
crates/btnode/src/main.rs

@@ -11,7 +11,7 @@ mod serde_tests;
 use serde_block_tree::{self, read_from, write_to};
 use harness::Message;
 mod crypto;
-use crypto::{Hash, HashKind, Signature, SymKey, AsymKeyPub, Cryptotext};
+use crypto::{Hash, HashKind, Signature, SymKey, AsymKeyPub, Cryptotext, Sign};
 
 use std::{
     collections::HashMap,
@@ -78,7 +78,7 @@ struct Writecap {
     /// The point in time after which this write cap is no longer valid.
     expires: Epoch,
     /// The public key used to sign this write cap.
-    signing_key: AsymKeyPub,
+    signing_key: AsymKeyPub<Sign>,
     /// A digital signature which covers all of the fields in the write cap except for next.
     signature: Signature,
     /// The next write cap in the chain leading back to the root.

+ 5 - 2
crates/btnode/src/test_helpers.rs

@@ -413,8 +413,11 @@ pub(crate) fn make_writecap_trusted_by<C: Creds>(
 }
 
 pub(crate) fn make_key_pair() -> impl Creds {
-    let public = AsymKeyPub::new(AsymKeyKind::Rsa, ROOT_PUBLIC_KEY.as_slice()).unwrap();
-    let private = RsaPriv::new(ROOT_PRIVATE_KEY.as_slice()).unwrap();
+    let public = ConcreteCredsPub {
+        sign: AsymKeyPub::new(Sign::RSA_PSS_3072_SHA_256, ROOT_PUBLIC_KEY.as_slice()).unwrap(),
+        encrypt : AsymKeyPub::new(Encrypt::RSA_OAEP_3072_SHA_256, ROOT_PUBLIC_KEY.as_slice()).unwrap(),
+    };
+    let private = ConcreteCredsPriv::new(ROOT_PRIVATE_KEY.as_slice()).unwrap();
     ConcreteCreds::new(public, private)
 }